// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <string>
#include <vector>

#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

29 30 31
class SlicePlugin : public PluginTensorRT {
 public:
  explicit SlicePlugin(std::vector<int> starts, std::vector<int> ends,
32
                       std::vector<int> axes, bool with_fp16);
33 34 35 36 37

  // It was used for tensorrt deserialization.
  // It should not be called by users.
  SlicePlugin(void const* serial_data, size_t serial_length);
  ~SlicePlugin();
38
  SlicePlugin* clone() const TRT_NOEXCEPT override;
39

40 41 42 43 44 45 46
  const char* getPluginType() const TRT_NOEXCEPT override {
    return "slice_plugin";
  }
  int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
  int initialize() TRT_NOEXCEPT override { return 0; }
  bool supportsFormat(nvinfer1::DataType type, nvinfer1::PluginFormat format)
      const TRT_NOEXCEPT override;
47
  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
48
                                     int nb_input_dims) TRT_NOEXCEPT override;
49
#if IS_TRT_VERSION_LT(8000)
50
  int enqueue(int batch_size, const void* const* inputs, void** outputs,
51 52 53
#else
  int enqueue(int batch_size, const void* const* inputs, void* const* outputs,
#endif
54
              void* workspace, cudaStream_t stream) TRT_NOEXCEPT override;
55

56
  size_t getSerializationSize() const TRT_NOEXCEPT override;
57 58 59

  // TRT will call this func  to serialize the configuration of TRT
  // It should not be called by users.
60
  void serialize(void* buffer) const TRT_NOEXCEPT override;
61 62 63 64 65 66 67 68 69 70

 private:
  std::vector<int> starts_;
  std::vector<int> ends_;
  std::vector<int> axes_;
  int* offset_temp_data_{nullptr};
  cudaEvent_t copy_event_;
  cudaStream_t copy_stream_;
};

71 72
class SlicePluginCreator : public TensorRTPluginCreator {
 public:
73 74 75
  const char* getPluginName() const TRT_NOEXCEPT override {
    return "slice_plugin";
  }
76

77
  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
78

79 80 81
  nvinfer1::IPluginV2* deserializePlugin(
      const char* name, const void* serial_data,
      size_t serial_length) TRT_NOEXCEPT override {
82 83 84 85 86
    return new SlicePlugin(serial_data, serial_length);
  }
};
REGISTER_TRT_PLUGIN_V2(SlicePluginCreator);

87 88 89 90
#if IS_TRT_VERSION_GE(6000)
class SlicePluginDynamic : public DynamicPluginTensorRT {
 public:
  explicit SlicePluginDynamic(std::vector<int> starts, std::vector<int> ends,
91
                              std::vector<int> axes, bool with_fp16);
92

93
  nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
94
    return new SlicePluginDynamic(starts_, ends_, axes_, with_fp16_);
95 96
  }

97 98
  SlicePluginDynamic(void const* serialData, size_t serialLength);

99 100 101 102 103
  const char* getPluginType() const TRT_NOEXCEPT override {
    return "slice_plugin_dynamic";
  }
  int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
  int initialize() TRT_NOEXCEPT override;
104

105 106
  size_t getSerializationSize() const TRT_NOEXCEPT override;
  void serialize(void* buffer) const TRT_NOEXCEPT override;
107 108 109

  nvinfer1::DimsExprs getOutputDimensions(
      int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
110
      nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT override;
111 112 113

  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc* inOut,
114 115
                                 int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;
116 117 118 119

  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc* out,
120
                       int nbOutputs) TRT_NOEXCEPT override {}
121 122 123 124

  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc* outputs,
125
                          int nbOutputs) const TRT_NOEXCEPT override {
126 127 128 129 130 131
    return 0;
  }

  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc,
              const void* const* inputs, void* const* outputs, void* workspace,
132 133 134 135
              cudaStream_t stream) TRT_NOEXCEPT override;
  nvinfer1::DataType getOutputDataType(
      int index, const nvinfer1::DataType* inputTypes,
      int nbInputs) const TRT_NOEXCEPT override;
136

137
  void destroy() TRT_NOEXCEPT override;
138 139 140 141 142

 private:
  std::vector<int> starts_;
  std::vector<int> ends_;
  std::vector<int> axes_;
143 144 145
  int* offset_temp_data_{nullptr};
  cudaEvent_t copy_event_;
  cudaStream_t copy_stream_;
146
};
147

148
class SlicePluginDynamicCreator : public TensorRTPluginCreator {
149
 public:
150 151 152
  const char* getPluginName() const TRT_NOEXCEPT override {
    return "slice_plugin_dynamic";
  }
153

154
  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
155

156 157 158
  nvinfer1::IPluginV2* deserializePlugin(
      const char* name, const void* serialData,
      size_t serialLength) TRT_NOEXCEPT override {
159
    return new SlicePluginDynamic(serialData, serialLength);
160 161
  }
};
W
Wilber 已提交
162
REGISTER_TRT_PLUGIN_V2(SlicePluginDynamicCreator);
163

164 165 166 167 168 169
#endif

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle