// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <stdio.h>
#include <cassert>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

static std::vector<int> CalcOutputSize(const std::vector<int>& input_shape,
                                       const bool& ceil_mode,
                                       const bool& adaptive,
                                       const std::vector<int>& ksize,
                                       const std::vector<int>& strides,
                                       const std::vector<int>& real_paddings) {
  std::vector<int> output_shape = input_shape;
  if (adaptive) {
    output_shape[0] = ksize[0];
    output_shape[1] = ksize[1];
  } else {
    int output_h = 0, output_w = 0;
    if (ceil_mode) {
      output_h = (input_shape[0] - ksize[0] + real_paddings[0] +
                  real_paddings[1] + strides[0] - 1) /
                     strides[0] +
                 1;
      output_w = (input_shape[1] - ksize[1] + real_paddings[2] +
                  real_paddings[3] + strides[1] - 1) /
                     strides[1] +
                 1;
    }
    // TRT will use its native pooling layer when ceil_mode=false
    /*
    else{
      output_h = (input_shape[0] - ksize[0] + real_paddings[0] +
    real_paddings[1]) / strides[0] + 1;
      output_w = (input_shape[1] - ksize[1] + real_paddings[2] +
    real_paddings[3]) / strides[1] + 1;
    }
    */
    output_shape[0] = output_h;
    output_shape[1] = output_w;
  }
  return output_shape;
}
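
// Illustrative example (hypothetical values): for input_shape {5, 5} with
// ksize {2, 2}, strides {2, 2} and real_paddings {0, 0, 0, 0}, ceil_mode=true
// gives {3, 3}, since (5 - 2 + 0 + 0 + 2 - 1) / 2 + 1 == 3, whereas floor
// rounding would give {2, 2}; with adaptive=true the output is simply ksize.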

class PoolPlugin : public PluginTensorRT {
 public:
  size_t getSerializationSize() const TRT_NOEXCEPT override;

  void serialize(void* buffer) const TRT_NOEXCEPT override;

  enum class PoolType {
    max = 0,
    avg,
  };
  PoolPlugin() {}
  PoolPlugin(bool ceil_mode, PoolType pool_type, bool adaptive, bool exclusive,
             std::vector<int> ksize, std::vector<int> strides,
             std::vector<int> paddings, std::vector<int> input_shape,
             std::vector<int> real_paddings)
      : ceil_mode_(ceil_mode),
        pool_type_(pool_type),
        adaptive_(adaptive),
        exclusive_(exclusive),
        ksize_(ksize),
        strides_(strides),
        paddings_(paddings),
        real_paddings_(real_paddings),
        input_shape_(input_shape) {
    output_shape_ = input_shape_;
    std::vector<int> output_shape =
        CalcOutputSize({input_shape_[1], input_shape_[2]}, ceil_mode_,
                       adaptive_, ksize_, strides_, real_paddings_);
    output_shape_[1] = output_shape[0];
    output_shape_[2] = output_shape[1];
  }
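
  // Illustrative use (hypothetical values): constructing
  // PoolPlugin(/*ceil_mode=*/true, PoolType::max, /*adaptive=*/false,
  //            /*exclusive=*/true, /*ksize=*/{2, 2}, /*strides=*/{2, 2},
  //            /*paddings=*/{0, 0}, /*input_shape=*/{3, 224, 224},
  //            /*real_paddings=*/{0, 0, 0, 0})
  // leaves output_shape_ == {3, 112, 112} (CHW, batch dimension excluded).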

  // This constructor is used for TensorRT deserialization.
  // It should not be called by users.
  PoolPlugin(void const* serialData, size_t serialLength) {
    deserializeBase(serialData, serialLength);
    DeserializeValue(&serialData, &serialLength, &ceil_mode_);
    DeserializeValue(&serialData, &serialLength, &pool_type_);
    DeserializeValue(&serialData, &serialLength, &adaptive_);
    DeserializeValue(&serialData, &serialLength, &exclusive_);
    DeserializeValue(&serialData, &serialLength, &ksize_);
    DeserializeValue(&serialData, &serialLength, &strides_);
    DeserializeValue(&serialData, &serialLength, &paddings_);
    DeserializeValue(&serialData, &serialLength, &real_paddings_);
    DeserializeValue(&serialData, &serialLength, &input_shape_);
    DeserializeValue(&serialData, &serialLength, &output_shape_);
  }

  PoolPlugin* clone() const TRT_NOEXCEPT override;

  const char* getPluginType() const TRT_NOEXCEPT override {
    return "pool_plugin";
  }
  int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
                                     int nbInputDims) TRT_NOEXCEPT override;
  int initialize() TRT_NOEXCEPT override { return 0; }
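  // TensorRT 8.0 changed the enqueue() outputs parameter from void** to
  // void* const*, hence the version guard below.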
#if IS_TRT_VERSION_LT(8000)
  int enqueue(int batchSize, const void* const* inputs, void** outputs,
#else
  int enqueue(int batchSize, const void* const* inputs, void* const* outputs,
#endif
              void* workspace, cudaStream_t stream) TRT_NOEXCEPT override;

 private:
  bool ceil_mode_;
  PoolType pool_type_;
  bool adaptive_;
  bool exclusive_;
  std::vector<int> ksize_;
  std::vector<int> strides_;
  std::vector<int> paddings_;
  std::vector<int> real_paddings_;
  std::vector<int> input_shape_;
  std::vector<int> output_shape_;
};

class PoolPluginCreator : public TensorRTPluginCreator {
 public:
  const char* getPluginName() const TRT_NOEXCEPT override {
    return "pool_plugin";
  }

  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }

  nvinfer1::IPluginV2* deserializePlugin(
      const char* name, const void* serial_data,
      size_t serial_length) TRT_NOEXCEPT override {
    return new PoolPlugin(serial_data, serial_length);
  }
};
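// REGISTER_TRT_PLUGIN_V2 (assumed, from its use across Paddle-TRT plugins)
// registers PoolPluginCreator with TensorRT's plugin registry so that
// deserializePlugin() above can rebuild a "pool_plugin" layer when a
// serialized engine is loaded.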
REGISTER_TRT_PLUGIN_V2(PoolPluginCreator);

#if IS_TRT_VERSION_GE(6000)
class PoolPluginDynamic : public DynamicPluginTensorRT {
 public:
  PoolPluginDynamic() {}
  PoolPluginDynamic(const bool& ceil_mode, const std::string& pool_type,
                    const bool& adaptive, bool exclusive,
                    const std::vector<int>& ksize,
                    const std::vector<int>& strides,
                    const std::vector<int>& paddings, const bool& is_global)
      : ceil_mode_(ceil_mode),
        pool_type_(pool_type),
        adaptive_(adaptive),
        exclusive_(exclusive),
        ksize_(ksize),
        strides_(strides),
        paddings_(paddings),
        is_global_(is_global) {}

  PoolPluginDynamic(void const* serialData, size_t serialLength);
  ~PoolPluginDynamic() {}
  nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override;

  const char* getPluginType() const TRT_NOEXCEPT override {
    return "pool_plugin_dynamic";
  }
  int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
  int initialize() TRT_NOEXCEPT override { return 0; }

  size_t getSerializationSize() const TRT_NOEXCEPT override;
  void serialize(void* buffer) const TRT_NOEXCEPT override;

  nvinfer1::DimsExprs getOutputDimensions(
      int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
      nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT override;

  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc* inOut,
                                 int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;

  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc* out,
                       int nbOutputs) TRT_NOEXCEPT override {}

  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc* outputs,
                          int nbOutputs) const TRT_NOEXCEPT override {
    return 0;
  }

  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc,
              const void* const* inputs, void* const* outputs, void* workspace,
              cudaStream_t stream) TRT_NOEXCEPT override;
  nvinfer1::DataType getOutputDataType(
      int index, const nvinfer1::DataType* inputTypes,
      int nbInputs) const TRT_NOEXCEPT override;

  void destroy() TRT_NOEXCEPT override { delete this; }

 private:
  bool ceil_mode_;
  std::string pool_type_;
  bool adaptive_;
  bool exclusive_;
  std::vector<int> ksize_;
  std::vector<int> strides_;
  std::vector<int> paddings_;
  bool is_global_;
};

class PoolPluginDynamicCreator : public TensorRTPluginCreator {
 public:
  const char* getPluginName() const TRT_NOEXCEPT override {
    return "pool_plugin_dynamic";
  }

  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }

  nvinfer1::IPluginV2* deserializePlugin(
      const char* name, const void* serial_data,
      size_t serial_length) TRT_NOEXCEPT override {
    return new PoolPluginDynamic(serial_data, serial_length);
  }
};
REGISTER_TRT_PLUGIN_V2(PoolPluginDynamicCreator);
#endif

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle