// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <NvInfer.h>

#include <cstring>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
30 31 32 33
namespace nvinfer1 {
class ITensor;
}  // namespace nvinfer1

34
DECLARE_bool(profile);
N
nhzlx 已提交
35 36 37 38

namespace paddle {
namespace inference {
namespace tensorrt {
39
namespace plugin {
N
nhzlx 已提交
40

41 42 43 44 45 46 47
#if defined(_WIN32)
#define UNUSED
#define __builtin_expect(EXP, C) (EXP)
#else
#define UNUSED __attribute__((unused))
#endif

N
nhzlx 已提交
48 49 50 51 52 53 54
class PluginTensorRT;

typedef std::function<PluginTensorRT*(const void*, size_t)>
    PluginDeserializeFunc;

typedef std::function<PluginTensorRT*(void)> PluginConstructFunc;

55
// Deprecated. Do not inherit this class, please refer to PluginTensorRTV2Ext
56
class PluginTensorRT : public nvinfer1::IPluginV2 {
N
nhzlx 已提交
57
 public:
58
  PluginTensorRT() : with_fp16_(false) {}
59

60 61
  // It was used for TensorRT deserialization.
  // It should not be called by users.
N
nhzlx 已提交
62
  PluginTensorRT(const void* serialized_data, size_t length) {}
63

64 65
  virtual ~PluginTensorRT() {}

N
nhzlx 已提交
66 67 68
  nvinfer1::Dims const& getInputDims(int index) const {
    return input_dims_.at(index);
  }
69

N
nhzlx 已提交
70
  nvinfer1::DataType getDataType() const { return data_type_; }
71

72
  nvinfer1::PluginFormat getDataFormat() const { return data_format_; }
73

74
  // IPluginV2
75
  virtual const char* getPluginType() const TRT_NOEXCEPT = 0;
76

77
  virtual const char* getPluginVersion() const TRT_NOEXCEPT { return "1"; }
78

79
  int getNbOutputs() const TRT_NOEXCEPT { return 1; }
80

81 82
  virtual nvinfer1::Dims getOutputDimensions(int index,
                                             const nvinfer1::Dims* input_dims,
83
                                             int num_inputs) TRT_NOEXCEPT = 0;
84 85

  // Check format support. The default is FLOAT32 and kLINEAR.
86 87
  bool supportsFormat(nvinfer1::DataType type, nvinfer1::PluginFormat format)
      const TRT_NOEXCEPT override;
88 89

  // Configure the layer
90 91 92 93
  void configureWithFormat(const nvinfer1::Dims* input_dims,
                           int num_inputs,
                           const nvinfer1::Dims* output_dims,
                           int num_outputs,
94 95
                           nvinfer1::DataType type,
                           nvinfer1::PluginFormat format,
96
                           int max_batch_size) TRT_NOEXCEPT override;
97 98

  // Initialize the layer for execution.
99
  int initialize() TRT_NOEXCEPT override { return 0; }
100

101
  // Shutdown the layer. This is called when the engine is destroyed
102
  void terminate() TRT_NOEXCEPT override {}
103 104

  // Find the workspace size required by the layer
105
  size_t getWorkspaceSize(int) const TRT_NOEXCEPT override { return 0; }
106

107 108
// Execute the layer
#if IS_TRT_VERSION_LT(8000)
109 110 111
  virtual int enqueue(int batch_size,
                      const void* const* inputs,
                      void** outputs,
112
#else
113 114
  virtual int enqueue(int batch_size,
                      const void* const* inputs,
115 116
                      void* const* outputs,
#endif
117 118
                      void* workspace,
                      cudaStream_t stream) TRT_NOEXCEPT = 0;
119 120

  // Find the size of the serialization buffer required
121
  virtual size_t getSerializationSize() const TRT_NOEXCEPT = 0;
122

123 124 125
  // Serialize the layer config to buffer.
  // TensorRT will call this func to serialize the configuration of TensorRT
  // engine. It should not be called by users.
126
  virtual void serialize(void* buffer) const TRT_NOEXCEPT = 0;
127

128
  void destroy() TRT_NOEXCEPT override { delete this; }
129

130
  virtual nvinfer1::IPluginV2* clone() const TRT_NOEXCEPT = 0;
131

132
  void setPluginNamespace(const char* plugin_namespace) TRT_NOEXCEPT override {
133 134 135
    namespace_ = plugin_namespace;
  }

136 137 138
  const char* getPluginNamespace() const TRT_NOEXCEPT override {
    return namespace_.c_str();
  }
N
nhzlx 已提交
139 140

 protected:
N
nhzlx 已提交
141
  // Deserialize input_dims, max_batch_size, data_type, data_format
142 143
  void deserializeBase(void const*& serial_data,  // NOLINT
                       size_t& serial_length);    // NOLINT
144
  size_t getBaseSerializationSize() const;
N
nhzlx 已提交
145
  // Serialize input_dims, max_batch_size, data_type, data_format
146
  void serializeBase(void*& buffer) const;  // NOLINT
N
nhzlx 已提交
147 148 149 150

  std::vector<nvinfer1::Dims> input_dims_;
  nvinfer1::DataType data_type_;
  nvinfer1::PluginFormat data_format_;
151

152
  bool with_fp16_;
153 154 155

 private:
  std::string namespace_;
N
nhzlx 已提交
156 157
};

158 159 160 161 162 163 164 165 166 167 168 169 170 171 172
// TensorRT introduced IPluginV2Ext after 5.1, Paddle no longer supports
// versions before 5.1
class PluginTensorRTV2Ext : public nvinfer1::IPluginV2Ext {
 public:
  PluginTensorRTV2Ext() : with_fp16_(false) {}
  PluginTensorRTV2Ext(const void* serialized_data, size_t length) {}

  nvinfer1::Dims const& getInputDims(int index) const {
    return input_dims_.at(index);
  }
  nvinfer1::DataType getDataType() const { return data_type_; }
  nvinfer1::PluginFormat getDataFormat() const { return data_format_; }

  // The Func in IPluginV2Ext
  virtual nvinfer1::DataType getOutputDataType(
173 174
      int index,
      const nvinfer1::DataType* input_types,
175
      int nb_inputs) const TRT_NOEXCEPT = 0;
176

177 178 179 180
  virtual bool isOutputBroadcastAcrossBatch(int32_t output_index,
                                            const bool* input_is_broadcasted,
                                            int32_t nb_inputs) const
      TRT_NOEXCEPT {
181 182 183
    return false;
  }

184 185
  virtual bool canBroadcastInputAcrossBatch(int32_t input_index) const
      TRT_NOEXCEPT {
186 187 188
    return false;
  }

189 190 191 192
  void configurePlugin(const nvinfer1::Dims* input_dims,
                       int32_t nb_inputs,
                       const nvinfer1::Dims* output_dims,
                       int32_t nb_outputs,
193 194 195 196 197
                       const nvinfer1::DataType* input_types,
                       const nvinfer1::DataType* output_types,
                       const bool* input_is_broadcast,
                       const bool* output_is_broadcast,
                       nvinfer1::PluginFormat float_format,
198
                       int32_t max_batch_size) TRT_NOEXCEPT override;
199

200
  virtual IPluginV2Ext* clone() const TRT_NOEXCEPT = 0;
201

202 203
  void attachToContext(cudnnContext*,
                       cublasContext*,
204
                       nvinfer1::IGpuAllocator*) TRT_NOEXCEPT override {}
205

206
  void detachFromContext() TRT_NOEXCEPT override {}
207 208

  // The Func in IPluginV2
209 210 211
  virtual const char* getPluginType() const TRT_NOEXCEPT = 0;
  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
  virtual int32_t getNbOutputs() const TRT_NOEXCEPT { return 1; }
212 213
  virtual nvinfer1::Dims getOutputDimensions(int32_t index,
                                             const nvinfer1::Dims* inputs,
214
                                             int32_t nb_input) TRT_NOEXCEPT = 0;
215
  // Check format support. The default is FLOAT32 and NCHW.
216 217
  bool supportsFormat(nvinfer1::DataType type, nvinfer1::PluginFormat format)
      const TRT_NOEXCEPT override {
218
    return ((type == nvinfer1::DataType::kFLOAT) &&
219
            (format == nvinfer1::PluginFormat::kLINEAR));
220 221 222
  }
  // Initialize the layer for execution.
  // This is called when the engine is created.
223
  int initialize() TRT_NOEXCEPT override { return 0; }
224 225

  // Shutdown the layer. This is called when the engine is destroyed
226
  void terminate() TRT_NOEXCEPT override {}
227 228

  // Find the workspace size required by the layer
229
  size_t getWorkspaceSize(int) const TRT_NOEXCEPT override { return 0; }
230

231 232
// Execute the layer
#if IS_TRT_VERSION_LT(8000)
233 234 235
  virtual int enqueue(int batch_size,
                      const void* const* inputs,
                      void** outputs,
236
#else
237 238
  virtual int enqueue(int batch_size,
                      const void* const* inputs,
239 240
                      void* const* outputs,
#endif
241 242
                      void* workspace,
                      cudaStream_t stream) TRT_NOEXCEPT = 0;
243 244

  // Find the size of the serialization buffer required
245
  virtual size_t getSerializationSize() const TRT_NOEXCEPT = 0;
246 247 248 249

  // Serialize the layer config to buffer.
  // TensorRT will call this func to serialize the configuration of TensorRT
  // engine. It should not be called by users.
250
  virtual void serialize(void* buffer) const TRT_NOEXCEPT = 0;
251

252
  virtual void destroy() TRT_NOEXCEPT = 0;
253

254
  void setPluginNamespace(const char* plugin_namespace) TRT_NOEXCEPT override {
255 256 257
    name_space_ = plugin_namespace;
  }

258
  const char* getPluginNamespace() const TRT_NOEXCEPT override {
259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277
    return name_space_.c_str();
  }

 protected:
  void deserializeBase(void const*& serial_data,  // NOLINT
                       size_t& serial_length);    // NOLINT
  size_t getBaseSerializationSize() const;
  void serializeBase(void*& buffer) const;  // NOLINT

 protected:
  std::vector<nvinfer1::Dims> input_dims_;
  nvinfer1::DataType data_type_;
  nvinfer1::PluginFormat data_format_;
  bool with_fp16_;

 private:
  std::string name_space_;
};

#if IS_TRT_VERSION_GE(6000)
// Base class for dynamic-shape plugins (IPluginV2DynamicExt, TensorRT >= 6).
class DynamicPluginTensorRT : public nvinfer1::IPluginV2DynamicExt {
 public:
  DynamicPluginTensorRT() : with_fp16_(false) {}
  // Used for TensorRT deserialization; should not be called by users.
  DynamicPluginTensorRT(const void* serialized_data, size_t length) {}

  // The Func in IPluginExt or IpluginExtV2
  virtual const char* getPluginVersion() const TRT_NOEXCEPT { return "1"; }
  virtual const char* getPluginType() const TRT_NOEXCEPT = 0;
  int getNbOutputs() const TRT_NOEXCEPT { return 1; }
  int initialize() TRT_NOEXCEPT override { return 0; }
  void terminate() TRT_NOEXCEPT override {}

  virtual size_t getSerializationSize() const TRT_NOEXCEPT = 0;
  virtual void serialize(void* buffer) const TRT_NOEXCEPT = 0;

  // The Func in IPluginV2
  nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT = 0;
  virtual nvinfer1::DimsExprs getOutputDimensions(
      int output_index,
      const nvinfer1::DimsExprs* inputs,
      int nb_inputs,
      nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT = 0;  // NOLINT

  virtual bool supportsFormatCombination(
      int pos,
      const nvinfer1::PluginTensorDesc* in_out,
      int nb_inputs,
      int nb_outputs) TRT_NOEXCEPT = 0;

  virtual void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
                               int nb_inputs,
                               const nvinfer1::DynamicPluginTensorDesc* out,
                               int nb_outputs) TRT_NOEXCEPT = 0;

  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nb_inputs,
                          const nvinfer1::PluginTensorDesc* outputs,
                          int nb_outputs) const TRT_NOEXCEPT override {
    return 0;
  }

  virtual int enqueue(const nvinfer1::PluginTensorDesc* input_desc,
                      const nvinfer1::PluginTensorDesc* output_desc,
                      const void* const* inputs,
                      void* const* outputs,
                      void* workspace,
                      cudaStream_t stream) TRT_NOEXCEPT = 0;

  virtual nvinfer1::DataType getOutputDataType(
      int index,
      const nvinfer1::DataType* input_types,
      int nb_inputs) const TRT_NOEXCEPT = 0;
  void setPluginNamespace(const char* plugin_namespace) TRT_NOEXCEPT override {
    name_space_ = plugin_namespace;
  }
  const char* getPluginNamespace() const TRT_NOEXCEPT override {
    return name_space_.c_str();
  }
  virtual void destroy() TRT_NOEXCEPT = 0;

 protected:
  // (De)serialize the base-class state (see trt_plugin_utils.h helpers).
  void deserializeBase(void const*& serial_data,  // NOLINT
                       size_t& serial_length);    // NOLINT
  size_t getBaseSerializationSize() const;
  void serializeBase(void*& buffer) const;  // NOLINT
  bool with_fp16_;

 private:
  std::string name_space_;
  std::string plugin_base_;
};
#endif

352 353 354 355
class TensorRTPluginCreator : public nvinfer1::IPluginCreator {
 public:
  TensorRTPluginCreator() = default;

356
  virtual const char* getPluginName() const TRT_NOEXCEPT = 0;
357

358
  virtual const char* getPluginVersion() const TRT_NOEXCEPT = 0;
359

360
  const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override;
361

362 363 364
  nvinfer1::IPluginV2* createPlugin(const char* name,
                                    const nvinfer1::PluginFieldCollection* fc)
      TRT_NOEXCEPT override;
365

366 367 368 369
  virtual nvinfer1::IPluginV2* deserializePlugin(const char* name,
                                                 const void* serial_data,
                                                 size_t serial_length)
      TRT_NOEXCEPT = 0;
370

371
  void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override;
372

373
  const char* getPluginNamespace() const TRT_NOEXCEPT override;
374 375 376 377 378 379 380 381

 private:
  std::string plugin_namespace_;
  std::string plugin_name_;
  nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
  std::vector<nvinfer1::PluginField> plugin_attributes_;
};

// Process-wide registry of deferred plugin-registration callbacks.
// Regist() records a callback under a name; RegistToTrt() later invokes all
// recorded callbacks (which register the creators with TensorRT).
class TrtPluginRegistry {
 public:
  // Meyers singleton: the single global registry instance.
  static TrtPluginRegistry* Global() {
    static TrtPluginRegistry registry;
    return &registry;
  }
  // Record `func` under `name`. Always returns true so it can initialize a
  // static bool in REGISTER_TRT_PLUGIN_V2. NOTE: a duplicate name is silently
  // ignored (emplace does not overwrite an existing entry).
  bool Regist(const std::string& name, const std::function<void()>& func) {
    map_.emplace(name, func);
    return true;
  }
  // Invoke every recorded registration callback.
  void RegistToTrt() {
    for (auto& it : map_) {
      it.second();
    }
  }

 private:
  std::unordered_map<std::string, std::function<void()>> map_;
};

P
Pei Yang 已提交
402 403 404
template <typename T>
class TrtPluginRegistrarV2 {
 public:
405 406 407 408 409 410
  TrtPluginRegistrarV2() {
    static auto func_ptr = GetPluginRegistry();
    if (func_ptr != nullptr) {
      func_ptr->registerCreator(creator, "");
    }
  }
P
Pei Yang 已提交
411 412 413 414 415

 private:
  T creator;
};

// Registers plugin creator class `name` for deferred registration with
// TensorRT. The UNUSED static bool forces Regist() to run at static-init
// time; the actual creator is only constructed when
// TrtPluginRegistry::RegistToTrt() invokes the stored callback.
#define REGISTER_TRT_PLUGIN_V2(name) REGISTER_TRT_PLUGIN_V2_HELPER(name)

#define REGISTER_TRT_PLUGIN_V2_HELPER(name)                                    \
  UNUSED static bool REGISTER_TRT_PLUGIN_V2_HELPER##name =                     \
      TrtPluginRegistry::Global()->Regist(#name, []() -> void {                \
        static paddle::inference::tensorrt::plugin::TrtPluginRegistrarV2<name> \
            plugin_registrar_##name{};                                         \
      });

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle