prelu_op_plugin.h 6.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

N
nhzlx 已提交
17
#include <algorithm>
18
#include <string>
N
nhzlx 已提交
19 20 21 22
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"

23 24 25 26 27 28
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
29
namespace plugin {
30 31

class PReluPlugin : public PluginTensorRT {
N
nhzlx 已提交
32
  std::vector<float> weight_;
33
  float* p_gpu_weight_;
34 35
  std::string mode_;

36
 public:
37
  size_t getSerializationSize() const TRT_NOEXCEPT override {
N
nhzlx 已提交
38
    return getBaseSerializationSize() + SerializedSize(mode_.c_str()) +
39
           SerializedSize(weight_);
40 41 42 43 44
  }

  // TRT will call this func when we need to serialize the configuration of
  // tensorrt.
  // It should not be called by users.
45
  void serialize(void* buffer) const TRT_NOEXCEPT override {
N
nhzlx 已提交
46 47 48
    serializeBase(buffer);
    SerializeValue(&buffer, weight_);
    SerializeValue(&buffer, mode_.c_str());
49 50
  }

51 52
  PReluPlugin(const float* weight, const int weight_num,
              std::string const& mode)
N
nhzlx 已提交
53 54 55 56
      : mode_(mode) {
    weight_.resize(weight_num);
    std::copy(weight, weight + weight_num, weight_.data());
  }
57 58 59

  // It was used for tensorrt deserialization.
  // It should not be called by users.
60
  PReluPlugin(void const* serialData, size_t serialLength) {
N
nhzlx 已提交
61 62
    deserializeBase(serialData, serialLength);
    DeserializeValue(&serialData, &serialLength, &weight_);
63
    const char* prelu_mode;
N
nhzlx 已提交
64 65
    DeserializeValue(&serialData, &serialLength, &prelu_mode);
    mode_ = std::string(prelu_mode);
66
  }
67
  ~PReluPlugin() {}
68 69
  int initialize() TRT_NOEXCEPT override;
  void terminate() TRT_NOEXCEPT override;
70

71
  PReluPlugin* clone() const TRT_NOEXCEPT override {
72 73 74
    auto* ptr = new PReluPlugin(weight_.data(), weight_.size(), mode_);
    ptr->p_gpu_weight_ = p_gpu_weight_;
    return ptr;
N
nhzlx 已提交
75
  }
76

77 78 79 80
  const char* getPluginType() const TRT_NOEXCEPT override {
    return "prelu_plugin";
  }
  int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
81
  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
82
                                     int nbInputDims) TRT_NOEXCEPT override;
83
#if IS_TRT_VERSION_LT(8000)
84
  int enqueue(int batchSize, const void* const* inputs, void** outputs,
85 86 87
#else
  int enqueue(int batchSize, const void* const* inputs, void* const* outputs,
#endif
88
              void* workspace, cudaStream_t stream) TRT_NOEXCEPT override;
89 90
};

91 92
class PReluPluginCreator : public TensorRTPluginCreator {
 public:
93 94 95
  const char* getPluginName() const TRT_NOEXCEPT override {
    return "prelu_plugin";
  }
96

97
  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
98

99 100 101
  nvinfer1::IPluginV2* deserializePlugin(
      const char* name, const void* serial_data,
      size_t serial_length) TRT_NOEXCEPT override {
102 103 104 105 106
    return new PReluPlugin(serial_data, serial_length);
  }
};
REGISTER_TRT_PLUGIN_V2(PReluPluginCreator);

107 108 109 110 111 112 113 114 115 116
#if IS_TRT_VERSION_GE(6000)
// Dynamic-shape (IPluginV2DynamicExt) variant of the PReLU plugin.
// Most methods are only declared here and defined out of line.
class PReluPluginDynamic : public DynamicPluginTensorRT {
 public:
  // Copies `weight_num` floats from `weight` (caller keeps ownership) and
  // stores `mode` verbatim.
  PReluPluginDynamic(const float* weight, const int weight_num,
                     std::string const& mode)
      : mode_(mode) {
    weight_.resize(weight_num);
    std::copy(weight, weight + weight_num, weight_.data());
  }

  // Deserializing constructor; defined out of line.
  PReluPluginDynamic(void const* serialData, size_t serialLength);
  ~PReluPluginDynamic() {}

  // Returns a new plugin with copies of weight_ and mode_.
  // NOTE(review): the clone aliases this object's p_gpu_weight_; confirm that
  // terminate() cannot free the same device buffer twice.
  nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
    auto* ptr = new PReluPluginDynamic(weight_.data(), weight_.size(), mode_);
    ptr->p_gpu_weight_ = p_gpu_weight_;
    return ptr;
  }

  // Must match PReluPluginDynamicCreator::getPluginName().
  const char* getPluginType() const TRT_NOEXCEPT override {
    return "prelu_plugin_dynamic";
  }
  int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
  int initialize() TRT_NOEXCEPT override;
  void terminate() TRT_NOEXCEPT override;

  size_t getSerializationSize() const TRT_NOEXCEPT override;
  void serialize(void* buffer) const TRT_NOEXCEPT override;

  nvinfer1::DimsExprs getOutputDimensions(
      int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
      nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT override;

  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc* inOut,
                                 int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;

  // Intentionally a no-op: nothing is recorded from the configured tensor
  // descriptions.
  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc* out,
                       int nbOutputs) TRT_NOEXCEPT override {}

  // No scratch workspace is requested from TensorRT.
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc* outputs,
                          int nbOutputs) const TRT_NOEXCEPT override {
    return 0;
  }

  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc,
              const void* const* inputs, void* const* outputs, void* workspace,
              cudaStream_t stream) TRT_NOEXCEPT override;
  nvinfer1::DataType getOutputDataType(
      int index, const nvinfer1::DataType* inputTypes,
      int nbInputs) const TRT_NOEXCEPT override;

  void destroy() TRT_NOEXCEPT override { delete this; }

 private:
  // Host copy of the PReLU weights.
  std::vector<float> weight_;
  // Device copy of the weights. Default-initialized so it is never read as
  // an indeterminate pointer before initialize() has run.
  float* p_gpu_weight_{nullptr};
  // PReLU mode string.
  std::string mode_;
};
#endif

173 174
class PReluPluginDynamicCreator : public TensorRTPluginCreator {
 public:
175 176 177
  const char* getPluginName() const TRT_NOEXCEPT override {
    return "prelu_plugin_dynamic";
  }
178

179
  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
180

181 182 183
  nvinfer1::IPluginV2* deserializePlugin(
      const char* name, const void* serial_data,
      size_t serial_length) TRT_NOEXCEPT override {
184 185 186 187 188
    return new PReluPluginDynamic(serial_data, serial_length);
  }
};
REGISTER_TRT_PLUGIN_V2(PReluPluginDynamicCreator);

189
}  // namespace plugin
190 191 192
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle