// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>

#include <cassert>
#include <vector>

#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h"
#include "paddle/fluid/operators/math/prelu.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

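// Copy the host-side alpha weights to device memory once at initialization,
// so enqueue() can read them directly on the GPU.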
int PReluPlugin::initialize() TRT_NOEXCEPT {
  cudaMalloc(&p_gpu_weight_, sizeof(float) * weight_.size());
  cudaMemcpy(p_gpu_weight_, weight_.data(), weight_.size() * sizeof(float),
             cudaMemcpyHostToDevice);
  return 0;
}

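// Release the device-side weight buffer allocated in initialize().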
void PReluPlugin::terminate() TRT_NOEXCEPT {
  if (p_gpu_weight_) {
    cudaFree(p_gpu_weight_);
    p_gpu_weight_ = nullptr;
  }
}

nvinfer1::Dims PReluPlugin::getOutputDimensions(int index,
                                                const nvinfer1::Dims *inputDims,
                                                int nbInputs) TRT_NOEXCEPT {
  assert(nbInputs == 1);
  assert(index < this->getNbOutputs());
  nvinfer1::Dims const &input_dims = inputDims[0];
  nvinfer1::Dims output_dims = input_dims;
  return output_dims;
}

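// PReLU computes out = max(0, x) + alpha * min(0, x). Depending on mode_,
// alpha is a per-channel vector ("channel"), a full per-element tensor
// ("element"), or a single scalar; the matching CUDA functor from
// operators::math is dispatched below.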
int PReluPlugin::enqueue(int batch_size, const void *const *inputs,
#if IS_TRT_VERSION_LT(8000)
                         void **outputs, void *workspace, cudaStream_t stream) {
#else
                         void *const *outputs, void *workspace,
                         cudaStream_t stream) TRT_NOEXCEPT {
#endif
  // Input dims are CHW; the batch dimension is implicit for this plugin.
  const auto &input_dims = this->getInputDims(0);
  const float *input = reinterpret_cast<const float *>(inputs[0]);
  // const float *alpha = reinterpret_cast<const float *>(alpha_.get().values);
  const float *alpha = p_gpu_weight_;
  float *const output = reinterpret_cast<float *const *>(outputs)[0];
  int numel = 1;
  for (int i = 0; i < input_dims.nbDims; i++) {
    numel *= input_dims.d[i];
  }

  if (mode_ == "channel") {
    bool channel_last = data_format_ == "NHWC";
    operators::math::PreluChannelWiseDirectCUDAFunctor<float>
        prelu_channel_wise;
    prelu_channel_wise(stream, input, alpha, output, input_dims.d[0],
                       input_dims.d[1], channel_last, numel);
  } else if (mode_ == "element") {
    operators::math::PreluElementWiseDirectCUDAFunctor<float>
        prelu_element_wise;
    prelu_element_wise(stream, input, alpha, output, input_dims.d[0], numel);
  } else {
    operators::math::PreluScalarDirectCUDAFunctor<float> prelu_scalar;
    prelu_scalar(stream, input, alpha, output, numel);
  }
  return cudaGetLastError() != cudaSuccess;
}

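// The dynamic-shape plugin below requires TensorRT >= 6.0, which introduced
// the IPluginV2DynamicExt interface.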
#if IS_TRT_VERSION_GE(6000)

void PReluPluginDynamic::terminate() TRT_NOEXCEPT {
  if (p_gpu_weight_) {
    cudaFree(p_gpu_weight_);
    p_gpu_weight_ = nullptr;
  }
}

int PReluPluginDynamic::initialize() TRT_NOEXCEPT {
  cudaMalloc(&p_gpu_weight_, sizeof(float) * weight_.size());
  cudaMemcpy(p_gpu_weight_, weight_.data(), weight_.size() * sizeof(float),
             cudaMemcpyHostToDevice);
  return 0;
}

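// Deserialization must read fields in the same order serialize() writes them:
// the weight vector first, then the mode string.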
PReluPluginDynamic::PReluPluginDynamic(void const *serialData,
                                       size_t serialLength) {
  DeserializeValue(&serialData, &serialLength, &weight_);
  const char *prelu_mode;
  DeserializeValue(&serialData, &serialLength, &prelu_mode);
  mode_ = std::string(prelu_mode);
}

size_t PReluPluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
  return SerializedSize(mode_.c_str()) + SerializedSize(weight_);
}

void PReluPluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {
  SerializeValue(&buffer, weight_);
  SerializeValue(&buffer, mode_.c_str());
}

nvinfer1::DimsExprs PReluPluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
    nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
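  // PReLU is an element-wise activation, so the output shape equals the
  // input shape.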
  return inputs[0];
}

bool PReluPluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
    int nb_outputs) TRT_NOEXCEPT {
  PADDLE_ENFORCE_NOT_NULL(
      in_out, platform::errors::InvalidArgument(
                  "The input of the prelu plugin should not be nullptr."));

  PADDLE_ENFORCE_LT(
      pos, nb_inputs + nb_outputs,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "num(%d) of the input and the output.",
                                        pos, nb_inputs + nb_outputs));

  return ((in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
          in_out[pos].format == nvinfer1::PluginFormat::kLINEAR);
}

nvinfer1::DataType PReluPluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType *input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument(
                                  "The PRelu Plugin only has one input, so the "
                                  "index value should be 0, but got %d.",
                                  index));
  PADDLE_ENFORCE_EQ((input_types[0] == nvinfer1::DataType::kFLOAT), true,
                    platform::errors::InvalidArgument(
                        "The input type should be float."));
  return input_types[0];
}

int PReluPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
                                const nvinfer1::PluginTensorDesc *output_desc,
                                const void *const *inputs, void *const *outputs,
                                void *workspace,
                                cudaStream_t stream) TRT_NOEXCEPT {
  auto input_dims = input_desc[0].dims;
  const float *alpha = p_gpu_weight_;
  const float *input = static_cast<const float *>(inputs[0]);
  float *output = static_cast<float *>(outputs[0]);
  int numel = 1;
  for (int i = 0; i < input_dims.nbDims; i++) {
    numel *= input_dims.d[i];
  }

  if (mode_ == "channel") {
    bool channel_last = data_format_ == "NHWC";
    operators::math::PreluChannelWiseDirectCUDAFunctor<float>
        prelu_channel_wise;
    prelu_channel_wise(stream, input, alpha, output, input_dims.d[0],
                       input_dims.d[1], channel_last, numel);
  } else if (mode_ == "element") {
    operators::math::PreluElementWiseDirectCUDAFunctor<float>
        prelu_element_wise;
    prelu_element_wise(stream, input, alpha, output, input_dims.d[0], numel);
  } else {
    operators::math::PreluScalarDirectCUDAFunctor<float> prelu_scalar;
    prelu_scalar(stream, input, alpha, output, numel);
  }
  return cudaGetLastError() != cudaSuccess;
}
#endif

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle