// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <string>
#include <vector>

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

class PReluPlugin : public PluginTensorRT {
  std::vector<float> weight_;
  float* p_gpu_weight_;
  std::string mode_;
  std::string data_format_;

 public:
  size_t getSerializationSize() const TRT_NOEXCEPT override {
    return getBaseSerializationSize() + SerializedSize(mode_.c_str()) +
           SerializedSize(data_format_.c_str()) + SerializedSize(weight_);
  }

  // TensorRT calls this function when it needs to serialize the plugin
  // configuration. It should not be called by users.
  void serialize(void* buffer) const TRT_NOEXCEPT override {
    serializeBase(buffer);
    SerializeValue(&buffer, weight_);
    SerializeValue(&buffer, mode_.c_str());
    SerializeValue(&buffer, data_format_.c_str());
  }
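
  // Illustrative round trip (a sketch with hypothetical variable names):
  // serialize() above and the deserialization constructor below read and
  // write the same fields in the same order.
  //
  //   std::vector<char> buf(plugin.getSerializationSize());
  //   plugin.serialize(buf.data());
  //   PReluPlugin restored(buf.data(), buf.size());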

  PReluPlugin(const float* weight,
              const int weight_num,
              std::string const& mode,
              std::string const& data_format)
      : mode_(mode), data_format_(data_format) {
    weight_.resize(weight_num);
    std::copy(weight, weight + weight_num, weight_.data());
  }

  // Used for TensorRT deserialization.
  // It should not be called by users.
  PReluPlugin(void const* serialData, size_t serialLength) {
    deserializeBase(serialData, serialLength);
    DeserializeValue(&serialData, &serialLength, &weight_);
    const char* prelu_mode;
    DeserializeValue(&serialData, &serialLength, &prelu_mode);
    mode_ = std::string(prelu_mode);
    const char* prelu_data_format;
    DeserializeValue(&serialData, &serialLength, &prelu_data_format);
    data_format_ = std::string(prelu_data_format);
  }
  ~PReluPlugin() {}
  int initialize() TRT_NOEXCEPT override;
  void terminate() TRT_NOEXCEPT override;

  PReluPlugin* clone() const TRT_NOEXCEPT override {
    auto* ptr =
        new PReluPlugin(weight_.data(), weight_.size(), mode_, data_format_);
    ptr->p_gpu_weight_ = p_gpu_weight_;
    return ptr;
  }

  const char* getPluginType() const TRT_NOEXCEPT override {
    return "prelu_plugin";
  }
  int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
  nvinfer1::Dims getOutputDimensions(int index,
                                     const nvinfer1::Dims* inputs,
                                     int nbInputDims) TRT_NOEXCEPT override;
#if IS_TRT_VERSION_LT(8000)
  int enqueue(int batchSize,
              const void* const* inputs,
              void** outputs,
#else
  int enqueue(int batchSize,
              const void* const* inputs,
              void* const* outputs,
#endif
              void* workspace,
              cudaStream_t stream) TRT_NOEXCEPT override;
};

class PReluPluginCreator : public TensorRTPluginCreator {
 public:
  const char* getPluginName() const TRT_NOEXCEPT override {
    return "prelu_plugin";
  }

  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }

  nvinfer1::IPluginV2* deserializePlugin(const char* name,
                                         const void* serial_data,
                                         size_t serial_length)
      TRT_NOEXCEPT override {
    return new PReluPlugin(serial_data, serial_length);
  }
};
REGISTER_TRT_PLUGIN_V2(PReluPluginCreator);
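
// Illustrative usage sketch (hypothetical variable names; converter code is
// not shown in this header): the static-shape plugin is built from the PReLU
// alpha weights and can be attached with TensorRT's standard addPluginV2 API.
//
//   std::vector<float> alpha = /* PReLU alpha weights */;
//   auto* prelu = new plugin::PReluPlugin(
//       alpha.data(), alpha.size(), /*mode=*/"channel", /*data_format=*/"NCHW");
//   nvinfer1::ITensor* input = /* layer input */;
//   auto* layer = network->addPluginV2(&input, 1, *prelu);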

#if IS_TRT_VERSION_GE(6000)
class PReluPluginDynamic : public DynamicPluginTensorRT {
 public:
  PReluPluginDynamic(const float* weight,
                     const int weight_num,
                     std::string const& mode,
                     std::string const& data_format)
      : mode_(mode), data_format_(data_format) {
    weight_.resize(weight_num);
    std::copy(weight, weight + weight_num, weight_.data());
  }

  PReluPluginDynamic(void const* serialData, size_t serialLength);
  ~PReluPluginDynamic() {}
  nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
    auto ptr = new PReluPluginDynamic(
        weight_.data(), weight_.size(), mode_, data_format_);
    ptr->p_gpu_weight_ = p_gpu_weight_;
    return ptr;
  }

  const char* getPluginType() const TRT_NOEXCEPT override {
    return "prelu_plugin_dynamic";
  }
  int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
  int initialize() TRT_NOEXCEPT override;
  void terminate() TRT_NOEXCEPT override;

  size_t getSerializationSize() const TRT_NOEXCEPT override;
  void serialize(void* buffer) const TRT_NOEXCEPT override;

  nvinfer1::DimsExprs getOutputDimensions(int output_index,
                                          const nvinfer1::DimsExprs* inputs,
                                          int nb_inputs,
                                          nvinfer1::IExprBuilder& expr_builder)
      TRT_NOEXCEPT override;

  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc* inOut,
                                 int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;

  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc* out,
                       int nbOutputs) TRT_NOEXCEPT override {}

  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc* outputs,
                          int nbOutputs) const TRT_NOEXCEPT override {
    return 0;
  }

  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc,
              const void* const* inputs,
              void* const* outputs,
              void* workspace,
              cudaStream_t stream) TRT_NOEXCEPT override;
  nvinfer1::DataType getOutputDataType(int index,
                                       const nvinfer1::DataType* inputTypes,
                                       int nbInputs) const
      TRT_NOEXCEPT override;

  void destroy() TRT_NOEXCEPT override { delete this; }

 private:
  std::vector<float> weight_;
  float* p_gpu_weight_;
  std::string mode_;
  std::string data_format_;
};
#endif

class PReluPluginDynamicCreator : public TensorRTPluginCreator {
 public:
  const char* getPluginName() const TRT_NOEXCEPT override {
    return "prelu_plugin_dynamic";
  }

  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }

  nvinfer1::IPluginV2* deserializePlugin(const char* name,
                                         const void* serial_data,
                                         size_t serial_length)
      TRT_NOEXCEPT override {
    return new PReluPluginDynamic(serial_data, serial_length);
  }
};
REGISTER_TRT_PLUGIN_V2(PReluPluginDynamicCreator);
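
// Illustrative sketch (hypothetical variable names): when a serialized engine
// containing this plugin is reloaded, TensorRT locates the creator registered
// above by the plugin name/version it reports and hands it the bytes that
// PReluPluginDynamic::serialize wrote.
//
//   PReluPluginDynamicCreator creator;
//   nvinfer1::IPluginV2* restored =
//       creator.deserializePlugin("prelu", serial_data, serial_length);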

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle