split_op_plugin.h 8.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
N
nhzlx 已提交
14 15 16

#pragma once

17
#include <string>
N
nhzlx 已提交
18
#include <utility>
19
#include <vector>
20
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
N
nhzlx 已提交
21 22 23 24

namespace paddle {
namespace inference {
namespace tensorrt {
25
namespace plugin {
N
nhzlx 已提交
26

27
class SplitPlugin : public PluginTensorRTV2Ext {
28
 public:
N
nhzlx 已提交
29
  SplitPlugin() {}
30 31 32 33
  SplitPlugin(int axis, std::vector<int> const& output_lengths, bool with_fp16)
      : axis_(axis), same_shape_(true), output_length_(output_lengths) {
    with_fp16_ = with_fp16;
  }
34

35
  SplitPlugin(void const* serial_data, size_t serial_length) {
36 37 38 39 40
    deserializeBase(serial_data, serial_length);
    DeserializeValue(&serial_data, &serial_length, &axis_);
    DeserializeValue(&serial_data, &serial_length, &output_length_);
  }

41
  nvinfer1::IPluginV2Ext* clone() const TRT_NOEXCEPT override {
42 43
    SplitPlugin* ptr = new SplitPlugin(axis_, output_length_, with_fp16_);
    ptr->setPluginNamespace(this->getPluginNamespace());
44 45
    ptr->shareData(this);
    return ptr;
46 47
  }

48 49 50
  nvinfer1::DataType getOutputDataType(
      int index, const nvinfer1::DataType* input_types,
      int nb_inputs) const TRT_NOEXCEPT override {
51 52 53
    return input_types[0];
  }

54 55 56 57 58 59
  const char* getPluginType() const TRT_NOEXCEPT override {
    return "split_plugin_v2ext";
  }
  int getNbOutputs() const TRT_NOEXCEPT override {
    return output_length_.size();
  }
60
  nvinfer1::Dims getOutputDimensions(int index,
61
                                     const nvinfer1::Dims* input_dims,
62
                                     int num_inputs) TRT_NOEXCEPT override;
63

64 65
  int initialize() TRT_NOEXCEPT override;
  void terminate() TRT_NOEXCEPT override;
66
#if IS_TRT_VERSION_LT(8000)
67
  int enqueue(int batch_size, const void* const* inputs, void** outputs,
68 69 70
#else
  int enqueue(int batch_size, const void* const* inputs, void* const* outputs,
#endif
71
              void* workspace, cudaStream_t stream) TRT_NOEXCEPT override;
N
nhzlx 已提交
72

73
  void destroy() TRT_NOEXCEPT override { delete this; }
74

N
nhzlx 已提交
75
 protected:
76
  size_t getSerializationSize() const TRT_NOEXCEPT override {
77 78
    return SerializedSize(axis_) + SerializedSize(output_length_) +
           getBaseSerializationSize();
N
nhzlx 已提交
79 80
  }

81
  void serialize(void* buffer) const TRT_NOEXCEPT override {
N
nhzlx 已提交
82
    serializeBase(buffer);
N
nhzlx 已提交
83 84
    SerializeValue(&buffer, axis_);
    SerializeValue(&buffer, output_length_);
N
nhzlx 已提交
85 86
  }

87
  int axis_;
H
hjchen2 已提交
88 89
  int outer_rows_;
  int inner_cols_;
90
  int axis_shape_;
H
hjchen2 已提交
91
  bool same_shape_;
92 93
  std::vector<int> output_length_;
  std::vector<int> segment_offsets_;
94 95 96

 private:
  void shareData(const SplitPlugin* another);
N
nhzlx 已提交
97 98
};

99 100 101
class SplitPluginCreator : public nvinfer1::IPluginCreator {
 public:
  SplitPluginCreator() {}
102 103 104
  const char* getPluginName() const TRT_NOEXCEPT override {
    return "split_plugin_v2ext";
  }
105

106
  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
107

108
  const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override {
109 110 111
    return &field_collection_;
  }

112 113 114
  nvinfer1::IPluginV2* createPlugin(const char* name,
                                    const nvinfer1::PluginFieldCollection* fc)
      TRT_NOEXCEPT override {
115 116 117 118
    // not implemented
    return nullptr;
  }

119 120 121
  nvinfer1::IPluginV2* deserializePlugin(
      const char* name, const void* serial_data,
      size_t serial_length) TRT_NOEXCEPT override {
122 123 124 125
    auto plugin = new SplitPlugin(serial_data, serial_length);
    return plugin;
  }

126
  void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override {
127 128 129
    plugin_namespace_ = lib_namespace;
  }

130
  const char* getPluginNamespace() const TRT_NOEXCEPT override {
131 132 133 134 135 136 137 138 139 140 141 142
    return plugin_namespace_.c_str();
  }

 private:
  std::string plugin_namespace_;
  std::string plugin_name_;
  nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
  std::vector<nvinfer1::PluginField> plugin_attributes_;
};

REGISTER_TRT_PLUGIN_V2(SplitPluginCreator);

143 144 145
#if IS_TRT_VERSION_GE(6000)
// IPluginV2DynamicExt variant of the split plugin, compiled only when
// IS_TRT_VERSION_GE(6000). The number of outputs equals
// output_length_.size(). Most overrides are defined out of line.
class SplitPluginDynamic : public DynamicPluginTensorRT {
 public:
  // axis:           split axis, stored in axis_.
  // output_lengths: one entry per output, stored in output_length_.
  // with_fp16:      stored in the base-class with_fp16_ flag.
  SplitPluginDynamic(int axis, std::vector<int> const& output_lengths,
                     bool with_fp16)
      : axis_(axis), output_length_(output_lengths) {
    with_fp16_ = with_fp16;
  }

  // Reads axis_, output_length_, with_fp16_ in that order. The out-of-line
  // serialize() must write them in the same order (verify when editing).
  SplitPluginDynamic(void const* serial_data, size_t serial_length) {
    DeserializeValue(&serial_data, &serial_length, &axis_);
    DeserializeValue(&serial_data, &serial_length, &output_length_);
    DeserializeValue(&serial_data, &serial_length, &with_fp16_);
  }

  // Returns a copy with the same axis, output lengths and fp16 flag. The
  // namespace is also carried over, as SplitPlugin::clone() does.
  nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
    auto* ptr = new SplitPluginDynamic(axis_, output_length_, with_fp16_);
    ptr->setPluginNamespace(this->getPluginNamespace());
    return ptr;
  }

  // Must match SplitPluginDynamicCreator::getPluginName().
  const char* getPluginType() const TRT_NOEXCEPT override {
    return "split_plugin";
  }
  // One output per entry of output_length_. The cast makes the
  // size_t -> int narrowing explicit.
  int getNbOutputs() const TRT_NOEXCEPT override {
    return static_cast<int>(output_length_.size());
  }
  int initialize() TRT_NOEXCEPT override;

  size_t getSerializationSize() const TRT_NOEXCEPT override;
  void serialize(void* buffer) const TRT_NOEXCEPT override;

  nvinfer1::DimsExprs getOutputDimensions(
      int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
      nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override;

  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc* inOut,
                                 int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;

  // Intentionally empty: nothing is cached from the configuration.
  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc* out,
                       int nbOutputs) TRT_NOEXCEPT override {}

  // No scratch workspace is requested.
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc* outputs,
                          int nbOutputs) const TRT_NOEXCEPT override {
    return 0;
  }

  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc,
              const void* const* inputs, void* const* outputs, void* workspace,
              cudaStream_t stream) TRT_NOEXCEPT override;
  nvinfer1::DataType getOutputDataType(
      int index, const nvinfer1::DataType* inputTypes,
      int nbInputs) const TRT_NOEXCEPT override;

  void destroy() TRT_NOEXCEPT override { delete this; }

 private:
  int axis_{0};
  // One entry per output (getNbOutputs() returns its size).
  std::vector<int> output_length_;
};

class SplitPluginDynamicCreator : public nvinfer1::IPluginCreator {
210
 public:
W
Wilber 已提交
211
  SplitPluginDynamicCreator() {}
212 213 214
  const char* getPluginName() const TRT_NOEXCEPT override {
    return "split_plugin";
  }
215

216
  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
217

218
  const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override {
219 220 221
    return &field_collection_;
  }

222 223 224
  nvinfer1::IPluginV2* createPlugin(const char* name,
                                    const nvinfer1::PluginFieldCollection* fc)
      TRT_NOEXCEPT override {
225 226 227
    return nullptr;
  }

228 229 230
  nvinfer1::IPluginV2* deserializePlugin(
      const char* name, const void* serial_data,
      size_t serial_length) TRT_NOEXCEPT override {
231 232 233 234
    auto plugin = new SplitPluginDynamic(serial_data, serial_length);
    return plugin;
  }

235
  void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override {
236 237 238
    plugin_namespace_ = lib_namespace;
  }

239
  const char* getPluginNamespace() const TRT_NOEXCEPT override {
240 241 242 243 244 245 246 247 248 249
    return plugin_namespace_.c_str();
  }

 private:
  std::string plugin_namespace_;
  std::string plugin_name_;
  nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
  std::vector<nvinfer1::PluginField> plugin_attributes_;
};

W
Wilber 已提交
250
REGISTER_TRT_PLUGIN_V2(SplitPluginDynamicCreator);
251 252
#endif

253 254 255 256
}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle