// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.h"
#include "paddle/phi/kernels/funcs/pooling.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

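// Static-shape plugin (implicit batch mode): the dims seen here are CHW
// without the batch dimension. The output H/W were already computed when the
// plugin was constructed and cached in output_shape_, so they are copied in.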
nvinfer1::Dims PoolPlugin::getOutputDimensions(int index,
                                               const nvinfer1::Dims *inputDims,
                                               int nbInputs) TRT_NOEXCEPT {
  assert(nbInputs == 1);
  assert(index == 0);
  assert(inputDims[0].nbDims == 3);
  nvinfer1::Dims const &input_dims = inputDims[0];

  nvinfer1::Dims output_dims = input_dims;

  output_dims.d[1] = output_shape_[1];
  output_dims.d[2] = output_shape_[2];
  return output_dims;
}

size_t PoolPlugin::getSerializationSize() const TRT_NOEXCEPT {
  return getBaseSerializationSize() + SerializedSize(ceil_mode_) +
         SerializedSize(pool_type_) + SerializedSize(adaptive_) +
         SerializedSize(exclusive_) + SerializedSize(ksize_) +
         SerializedSize(strides_) + SerializedSize(paddings_) +
         SerializedSize(real_paddings_) + SerializedSize(input_shape_) +
         SerializedSize(output_shape_);
}

// TensorRT calls this function when it needs to serialize the plugin's
// configuration into the engine.
void PoolPlugin::serialize(void *buffer) const TRT_NOEXCEPT {
  serializeBase(buffer);
  SerializeValue(&buffer, ceil_mode_);
  SerializeValue(&buffer, pool_type_);
  SerializeValue(&buffer, adaptive_);
  SerializeValue(&buffer, exclusive_);
  SerializeValue(&buffer, ksize_);
  SerializeValue(&buffer, strides_);
  SerializeValue(&buffer, paddings_);
  SerializeValue(&buffer, real_paddings_);
  SerializeValue(&buffer, input_shape_);
  SerializeValue(&buffer, output_shape_);
}
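
// Note: the deserializing PoolPlugin constructor must read these fields back
// in exactly the order serialize() writes them above, or the plugin state
// will be corrupted on engine reload.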

PoolPlugin *PoolPlugin::clone() const TRT_NOEXCEPT {
  return new PoolPlugin(ceil_mode_,
                        pool_type_,
                        adaptive_,
                        exclusive_,
                        ksize_,
                        strides_,
                        paddings_,
                        input_shape_,
                        real_paddings_);
}

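// TensorRT 8 changed enqueue()'s outputs parameter from "void **" to
// "void *const *", hence the version guard around the signature below.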
int PoolPlugin::enqueue(int batchSize,
                        const void *const *inputs,
#if IS_TRT_VERSION_LT(8000)
                        void **outputs,
                        void *workspace,
                        cudaStream_t stream) TRT_NOEXCEPT {
#else
                        void *const *outputs,
                        void *workspace,
                        cudaStream_t stream) TRT_NOEXCEPT {
#endif
  auto const &input_dims = this->getInputDims(0);
  float const *idata = reinterpret_cast<float const *>(inputs[0]);
  float *const *odatas = reinterpret_cast<float *const *>(outputs);

  // input_shape_/output_shape_ are stored without the batch dimension
  // (implicit batch mode), so prepend batchSize before calling the functor.
  std::vector<int> input_shape = input_shape_;
  std::vector<int> output_shape = output_shape_;
  input_shape.insert(input_shape.begin(), batchSize);
  output_shape.insert(output_shape.begin(), batchSize);

  if (pool_type_ == PoolType::max) {
    phi::funcs::MaxPool<float> pool_process;
    phi::funcs::Pool2dDirectCUDAFunctor<phi::funcs::MaxPool<float>, float>
        pool2d_forward;
    pool2d_forward(idata,
                   input_shape,
                   output_shape,
                   ksize_,
                   strides_,
                   paddings_,
                   /*exclusive=*/true,
                   /*adaptive=*/false,
                   odatas[0],
                   stream,
                   pool_process);
  } else if (pool_type_ == PoolType::avg) {
    phi::funcs::AvgPool<float> pool_process;
    phi::funcs::Pool2dDirectCUDAFunctor<phi::funcs::AvgPool<float>, float>
        pool2d_forward;
    pool2d_forward(idata,
                   input_shape,
                   output_shape,
                   ksize_,
                   strides_,
                   paddings_,
                   exclusive_,
                   adaptive_,
                   odatas[0],
                   stream,
                   pool_process);
  }

  return cudaGetLastError() != cudaSuccess;
}

// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)

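// Deserializing constructor: fields are read back in the order serialize()
// writes them. pool_type_ is serialized as a C string, so it is first read
// into a const char * and then copied into the std::string member.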
PoolPluginDynamic::PoolPluginDynamic(void const *serialData,
                                     size_t serialLength) {
  DeserializeValue(&serialData, &serialLength, &ceil_mode_);
  const char *pool_type;
  DeserializeValue(&serialData, &serialLength, &pool_type);
  pool_type_ = std::string(pool_type);
  DeserializeValue(&serialData, &serialLength, &adaptive_);
  DeserializeValue(&serialData, &serialLength, &exclusive_);
  DeserializeValue(&serialData, &serialLength, &ksize_);
  DeserializeValue(&serialData, &serialLength, &strides_);
  DeserializeValue(&serialData, &serialLength, &paddings_);
  DeserializeValue(&serialData, &serialLength, &is_global_);
}

size_t PoolPluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
  return SerializedSize(ceil_mode_) + SerializedSize(pool_type_.c_str()) +
         SerializedSize(adaptive_) + SerializedSize(exclusive_) +
         SerializedSize(ksize_) + SerializedSize(strides_) +
         SerializedSize(paddings_) + SerializedSize(is_global_);
}

void PoolPluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {
  SerializeValue(&buffer, ceil_mode_);
  SerializeValue(&buffer, pool_type_.c_str());
  SerializeValue(&buffer, adaptive_);
  SerializeValue(&buffer, exclusive_);
  SerializeValue(&buffer, ksize_);
  SerializeValue(&buffer, strides_);
  SerializeValue(&buffer, paddings_);
  SerializeValue(&buffer, is_global_);
}

nvinfer1::IPluginV2DynamicExt *PoolPluginDynamic::clone() const TRT_NOEXCEPT {
  return new PoolPluginDynamic(ceil_mode_,
                               pool_type_,
                               adaptive_,
                               exclusive_,
                               ksize_,
                               strides_,
                               paddings_,
                               is_global_);
}

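// Builds the output shape symbolically via IExprBuilder. Three cases bypass
// the general formula: global pooling collapses H/W to 1x1, global adaptive
// pooling keeps the input shape, and plain adaptive pooling outputs ksize_.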
nvinfer1::DimsExprs PoolPluginDynamic::getOutputDimensions(
    int output_index,
    const nvinfer1::DimsExprs *inputs,
    int nb_inputs,
    nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(nb_inputs,
                    1,
                    platform::errors::InvalidArgument(
                        "The Split plugin should be only one input."));

  nvinfer1::DimsExprs output(inputs[0]);
  if (is_global_ && !adaptive_) {
    output.d[2] = expr_builder.constant(1);
    output.d[3] = expr_builder.constant(1);
    return output;
  }
  if (is_global_ && adaptive_) {
    return inputs[0];
  }
  if (adaptive_) {
    output.d[2] = expr_builder.constant(ksize_[0]);
    output.d[3] = expr_builder.constant(ksize_[1]);
    return output;
  }

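  // General case: standard pooling arithmetic, built from IDimensionExpr
  // operations so TensorRT can evaluate it for any runtime input extent:
  //   floor mode: H_out = (H_in + 2*pad - ksize) / stride + 1
  //   ceil mode:  H_out = (H_in + 2*pad - ksize + stride - 1) / stride + 1
  // e.g. H_in = 224, ksize = 3, stride = 2, pad = 1 in floor mode gives
  // (224 + 2 - 3) / 2 + 1 = 112. The constants below fold the
  // "-ksize + 2*pad (+ stride - 1)" term into a single expression.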
  auto stri_0 = expr_builder.constant(strides_[0]);
  auto stri_1 = expr_builder.constant(strides_[1]);
  auto one_value = expr_builder.constant(1);

  auto v0_tmp = expr_builder.constant(-ksize_[0] + 2 * paddings_[0]);
  auto v1_tmp = expr_builder.constant(-ksize_[1] + 2 * paddings_[1]);

  auto ceil_tmp =
      expr_builder.constant(-ksize_[0] + 2 * paddings_[0] + strides_[0] - 1);
  auto ceil1_tmp =
      expr_builder.constant(-ksize_[1] + 2 * paddings_[1] + strides_[1] - 1);

  if (!ceil_mode_) {
    output.d[2] = expr_builder.operation(
        nvinfer1::DimensionOperation::kSUM,
        *expr_builder.operation(
            nvinfer1::DimensionOperation::kFLOOR_DIV,
            *expr_builder.operation(
                nvinfer1::DimensionOperation::kSUM, *inputs[0].d[2], *v0_tmp),
            *stri_0),
        *one_value);
    output.d[3] = expr_builder.operation(
        nvinfer1::DimensionOperation::kSUM,
        *expr_builder.operation(
            nvinfer1::DimensionOperation::kFLOOR_DIV,
            *expr_builder.operation(
                nvinfer1::DimensionOperation::kSUM, *inputs[0].d[3], *v1_tmp),
            *stri_1),
        *one_value);

  } else {
    output.d[2] = expr_builder.operation(
        nvinfer1::DimensionOperation::kSUM,
        *expr_builder.operation(
            nvinfer1::DimensionOperation::kFLOOR_DIV,
            *expr_builder.operation(
                nvinfer1::DimensionOperation::kSUM, *inputs[0].d[2], *ceil_tmp),
            *stri_0),
        *one_value);
    output.d[3] = expr_builder.operation(
        nvinfer1::DimensionOperation::kSUM,
        *expr_builder.operation(
            nvinfer1::DimensionOperation::kFLOOR_DIV,
            *expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
                                    *inputs[0].d[3],
                                    *ceil1_tmp),
            *stri_1),
        *one_value);
  }

  return output;
}

bool PoolPluginDynamic::supportsFormatCombination(
    int pos,
    const nvinfer1::PluginTensorDesc *in_out,
    int nb_inputs,
    int nb_outputs) TRT_NOEXCEPT {
  PADDLE_ENFORCE_NOT_NULL(
      in_out,
      platform::errors::InvalidArgument(
          "The input of swish plugin shoule not be nullptr."));

  PADDLE_ENFORCE_LT(
      pos,
      nb_inputs + nb_outputs,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "num(%d) of the input and the output.",
                                        pos,
                                        nb_inputs + nb_outputs));
  // Only FP32 tensors in TensorRT's linear format are supported, for both
  // the input and the output.

  return ((in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
          in_out[pos].format == nvinfer1::PluginFormat::kLINEAR);
}

nvinfer1::DataType PoolPluginDynamic::getOutputDataType(
    int index,
    const nvinfer1::DataType *input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(index,
                    0,
                    platform::errors::InvalidArgument(
                        "The Pool Plugin only has one input, so the "
                        "index value should be 0, but get %d.",
                        index));
  PADDLE_ENFORCE_EQ((input_types[0] == nvinfer1::DataType::kFLOAT),
                    true,
                    platform::errors::InvalidArgument(
                        "The input type should be half or float"));
  return input_types[0];
}

int PoolPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
                               const nvinfer1::PluginTensorDesc *output_desc,
                               const void *const *inputs,
                               void *const *outputs,
                               void *workspace,
                               cudaStream_t stream) TRT_NOEXCEPT {
  // With dynamic shapes, the actual extents are only known at runtime, so
  // read them from the input descriptor.
  auto input_dims = input_desc[0].dims;
  int n = input_dims.d[0];
  int c = input_dims.d[1];
  int h = input_dims.d[2];
  int w = input_dims.d[3];

  const float *input = static_cast<const float *>(inputs[0]);
  float *output = static_cast<float *>(outputs[0]);

  std::vector<int> input_shape, output_shape;
  for (int i = 0; i < input_dims.nbDims; i++)
    input_shape.push_back(input_dims.d[i]);
  output_shape = input_shape;

  std::vector<int> ksize = ksize_;
  std::vector<int> paddings = paddings_;
  if (is_global_) {
    ksize[0] = h;
    ksize[1] = w;
    paddings[0] = 0;
    paddings[1] = 0;
    output_shape[2] = 1;
    output_shape[3] = 1;
  } else {
    auto data_dim = CalcOutputSize(
        {h, w}, ceil_mode_, adaptive_, ksize_, strides_, paddings_);
    output_shape[2] = data_dim[0];
    output_shape[3] = data_dim[1];
  }
  if (adaptive_) {
    output_shape[2] = h;
    output_shape[3] = w;
  }

  if (pool_type_ == "max") {
    phi::funcs::MaxPool<float> pool_process;
    phi::funcs::Pool2dDirectCUDAFunctor<phi::funcs::MaxPool<float>, float>
        pool2d_forward;
    pool2d_forward(input,
                   input_shape,
                   output_shape,
                   ksize,
                   strides_,
                   paddings,
                   /*exclusive=*/true,
                   /*adaptive=*/false,
                   output,
                   stream,
                   pool_process);
  } else if (pool_type_ == "avg") {
    phi::funcs::AvgPool<float> pool_process;
    phi::funcs::Pool2dDirectCUDAFunctor<phi::funcs::AvgPool<float>, float>
        pool2d_forward;
    pool2d_forward(input,
                   input_shape,
                   output_shape,
                   ksize,
                   strides_,
                   paddings,
                   exclusive_,
                   adaptive_,
                   output,
                   stream,
                   pool_process);
  }

  return cudaGetLastError() != cudaSuccess;
}
#endif
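
// Illustrative sketch (not part of the original file): constructing the
// dynamic plugin with the same argument order clone() uses above. Attaching
// it to a network is handled by the Paddle-TRT pool2d op converter and is
// omitted here.
//
//   std::vector<int> ksize{3, 3}, strides{2, 2}, paddings{1, 1};
//   auto *pool_plugin = new PoolPluginDynamic(
//       /*ceil_mode=*/false, /*pool_type=*/"max", /*adaptive=*/false,
//       /*exclusive=*/true, ksize, strides, paddings, /*is_global=*/false);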

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle