// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, softwarepool // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/inference/tensorrt/plugin/pool3d_op_plugin.h" #include "paddle/fluid/operators/math/pooling.h" namespace paddle { namespace inference { namespace tensorrt { namespace plugin { size_t Pool3DPlugin::getSerializationSize() const TRT_NOEXCEPT { return getBaseSerializationSize() + SerializedSize(ceil_mode_) + SerializedSize(pool3d_type_) + SerializedSize(adaptive_) + SerializedSize(ksize_) + SerializedSize(strides_) + SerializedSize(paddings_) + SerializedSize(input_shape_) + SerializedSize(output_shape_); } // TRT will call this func when we need to serialize the configuration of // tensorrt. void Pool3DPlugin::serialize(void *buffer) const TRT_NOEXCEPT { serializeBase(buffer); SerializeValue(&buffer, ceil_mode_); SerializeValue(&buffer, pool3d_type_); SerializeValue(&buffer, adaptive_); SerializeValue(&buffer, ksize_); SerializeValue(&buffer, strides_); SerializeValue(&buffer, paddings_); SerializeValue(&buffer, input_shape_); SerializeValue(&buffer, output_shape_); } Pool3DPlugin *Pool3DPlugin::clone() const TRT_NOEXCEPT { return new Pool3DPlugin(ceil_mode_, pool3d_type_, adaptive_, ksize_, strides_, paddings_, input_shape_); } const char *Pool3DPlugin::getPluginType() const TRT_NOEXCEPT { return "pool3d_plugin"; } int Pool3DPlugin::getNbOutputs() const TRT_NOEXCEPT { return 1; } int Pool3DPlugin::initialize() TRT_NOEXCEPT { return 0; } nvinfer1::DataType Pool3DPlugin::getOutputDataType( int index, const nvinfer1::DataType *input_types, int nb_inputs) const TRT_NOEXCEPT { return input_types[0]; } void Pool3DPlugin::destroy() TRT_NOEXCEPT { delete this; } nvinfer1::Dims Pool3DPlugin::getOutputDimensions( int index, const nvinfer1::Dims *inputDims, int nbInputs) TRT_NOEXCEPT { PADDLE_ENFORCE_EQ(nbInputs, 1, platform::errors::InvalidArgument( "The Pool3D Plugin only has one input, so the nbInputs " "value should be 1, but get %d.", nbInputs)); PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument( "The Pool3D Plugin only has one input, so " "the index value should be 0, but get %d.", index)); PADDLE_ENFORCE_EQ(inputDims[0].nbDims, 4, platform::errors::InvalidArgument( "The Pool3D Plugin only has four Dimensions, so the " "nbDims value should be 4, but get %d.", inputDims[0].nbDims)); nvinfer1::Dims const &input_dims = inputDims[0]; nvinfer1::Dims output_dims = input_dims; output_dims.d[1] = output_shape_[1]; output_dims.d[2] = output_shape_[2]; output_dims.d[3] = output_shape_[3]; return output_dims; } int Pool3DPlugin::enqueue(int batchSize, const void *const *inputs, #if IS_TRT_VERSION_LT(8000) void **outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT { #else void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT { #endif int input_size = 0; float const *idata = reinterpret_cast(inputs[0]); float *const *odatas = reinterpret_cast(outputs); std::vector input_shape = input_shape_; std::vector output_shape = output_shape_; input_shape.insert(input_shape.begin(), batchSize); output_shape.insert(output_shape.begin(), batchSize); if (pool3d_type_ == Pool3DType::max) { paddle::operators::math::MaxPool pool_process; paddle::operators::math::Pool3dDirectCUDAFunctor< paddle::operators::math::MaxPool, float> pool3d_forward; pool3d_forward(idata, input_shape, output_shape, ksize_, strides_, paddings_, true, adaptive_, odatas[0], stream, pool_process); } else if (pool3d_type_ == Pool3DType::avg) { paddle::operators::math::AvgPool pool_process; paddle::operators::math::Pool3dDirectCUDAFunctor< paddle::operators::math::AvgPool, float> pool3d_forward; pool3d_forward(idata, input_shape, output_shape, ksize_, strides_, paddings_, true, adaptive_, odatas[0], stream, pool_process); } return cudaGetLastError() != cudaSuccess; } // Dynamic Plugin below. Pool3DPluginDynamic::Pool3DPluginDynamic(void const *serialData, size_t serialLength) { DeserializeValue(&serialData, &serialLength, &ceil_mode_); const char *pool3d_type; DeserializeValue(&serialData, &serialLength, &pool3d_type); pool3d_type_ = std::string(pool3d_type); DeserializeValue(&serialData, &serialLength, &adaptive_); DeserializeValue(&serialData, &serialLength, &ksize_); DeserializeValue(&serialData, &serialLength, &strides_); DeserializeValue(&serialData, &serialLength, &paddings_); DeserializeValue(&serialData, &serialLength, &is_global_); } nvinfer1::IPluginV2DynamicExt *Pool3DPluginDynamic::clone() const TRT_NOEXCEPT { return new Pool3DPluginDynamic(ceil_mode_, pool3d_type_, adaptive_, ksize_, strides_, paddings_, is_global_); } const char *Pool3DPluginDynamic::getPluginType() const TRT_NOEXCEPT { return "pool3d_plugin_dynamic"; } int Pool3DPluginDynamic::getNbOutputs() const TRT_NOEXCEPT { return 1; } int Pool3DPluginDynamic::initialize() TRT_NOEXCEPT { return 0; } void Pool3DPluginDynamic::configurePlugin( const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) TRT_NOEXCEPT {} size_t Pool3DPluginDynamic::getWorkspaceSize( const nvinfer1::PluginTensorDesc *inputs, int nbInputs, const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const TRT_NOEXCEPT { return 0; } size_t Pool3DPluginDynamic::getSerializationSize() const TRT_NOEXCEPT { return SerializedSize(ceil_mode_) + SerializedSize(pool3d_type_.c_str()) + SerializedSize(adaptive_) + SerializedSize(ksize_) + SerializedSize(strides_) + SerializedSize(paddings_) + SerializedSize(is_global_); } void Pool3DPluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT { SerializeValue(&buffer, ceil_mode_); SerializeValue(&buffer, pool3d_type_.c_str()); SerializeValue(&buffer, adaptive_); SerializeValue(&buffer, ksize_); SerializeValue(&buffer, strides_); SerializeValue(&buffer, paddings_); SerializeValue(&buffer, is_global_); } nvinfer1::DimsExprs Pool3DPluginDynamic::getOutputDimensions( int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs, nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT { PADDLE_ENFORCE_EQ(nb_inputs, 1, platform::errors::InvalidArgument( "The Split plugin should be only one input.")); PADDLE_ENFORCE_EQ( inputs[0].d[1]->isConstant(), true, platform::errors::InvalidArgument("The channel dimension should be " "static, but we found it's dynamic.")); nvinfer1::DimsExprs output(inputs[0]); if (is_global_) { output.d[2] = expr_builder.constant(1); output.d[3] = expr_builder.constant(1); output.d[4] = expr_builder.constant(1); return output; } if (adaptive_) { output.d[2] = expr_builder.constant(ksize_[0]); output.d[3] = expr_builder.constant(ksize_[1]); output.d[4] = expr_builder.constant(ksize_[2]); return output; } auto stri_0 = expr_builder.constant(strides_[0]); auto stri_1 = expr_builder.constant(strides_[1]); auto stri_2 = expr_builder.constant(strides_[2]); auto one_value = expr_builder.constant(1); auto v0_tmp = expr_builder.constant(-ksize_[0] + 2 * paddings_[0]); auto v1_tmp = expr_builder.constant(-ksize_[1] + 2 * paddings_[1]); auto v2_tmp = expr_builder.constant(-ksize_[2] + 2 * paddings_[2]); auto ceil_tmp = expr_builder.constant(-ksize_[0] + 2 * paddings_[0] + strides_[0] - 1); auto ceil1_tmp = expr_builder.constant(-ksize_[1] + 2 * paddings_[1] + strides_[1] - 1); auto ceil2_tmp = expr_builder.constant(-ksize_[2] + 2 * paddings_[2] + strides_[2] - 1); if (!ceil_mode_) { output.d[2] = expr_builder.operation( nvinfer1::DimensionOperation::kSUM, *expr_builder.operation( nvinfer1::DimensionOperation::kFLOOR_DIV, *expr_builder.operation(nvinfer1::DimensionOperation::kSUM, *inputs[0].d[2], *v0_tmp), *stri_0), *one_value); output.d[3] = expr_builder.operation( nvinfer1::DimensionOperation::kSUM, *expr_builder.operation( nvinfer1::DimensionOperation::kFLOOR_DIV, *expr_builder.operation(nvinfer1::DimensionOperation::kSUM, *inputs[0].d[3], *v1_tmp), *stri_1), *one_value); output.d[4] = expr_builder.operation( nvinfer1::DimensionOperation::kSUM, *expr_builder.operation( nvinfer1::DimensionOperation::kFLOOR_DIV, *expr_builder.operation(nvinfer1::DimensionOperation::kSUM, *inputs[0].d[4], *v2_tmp), *stri_2), *one_value); } else { output.d[2] = expr_builder.operation( nvinfer1::DimensionOperation::kSUM, *expr_builder.operation( nvinfer1::DimensionOperation::kFLOOR_DIV, *expr_builder.operation(nvinfer1::DimensionOperation::kSUM, *inputs[0].d[2], *ceil_tmp), *stri_0), *one_value); output.d[3] = expr_builder.operation( nvinfer1::DimensionOperation::kSUM, *expr_builder.operation( nvinfer1::DimensionOperation::kFLOOR_DIV, *expr_builder.operation(nvinfer1::DimensionOperation::kSUM, *inputs[0].d[3], *ceil1_tmp), *stri_1), *one_value); output.d[4] = expr_builder.operation( nvinfer1::DimensionOperation::kSUM, *expr_builder.operation( nvinfer1::DimensionOperation::kFLOOR_DIV, *expr_builder.operation(nvinfer1::DimensionOperation::kSUM, *inputs[0].d[4], *ceil2_tmp), *stri_2), *one_value); } return output; } bool Pool3DPluginDynamic::supportsFormatCombination( int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs, int nb_outputs) TRT_NOEXCEPT { PADDLE_ENFORCE_NOT_NULL( in_out, platform::errors::InvalidArgument( "The input of swish plugin shoule not be nullptr.")); PADDLE_ENFORCE_LT( pos, nb_inputs + nb_outputs, platform::errors::InvalidArgument("The pos(%d) should be less than the " "num(%d) of the input and the output.", pos, nb_inputs + nb_outputs)); (in_out && pos < (nb_inputs + nb_outputs)); return ((in_out[pos].type == nvinfer1::DataType::kFLOAT) && in_out[pos].format == nvinfer1::PluginFormat::kLINEAR); } nvinfer1::DataType Pool3DPluginDynamic::getOutputDataType( int index, const nvinfer1::DataType *input_types, int nb_inputs) const TRT_NOEXCEPT { PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument( "The Pool3D Plugin only has one input, so the " "index value should be 0, but get %d.", index)); PADDLE_ENFORCE_EQ((input_types[0] == nvinfer1::DataType::kFLOAT), true, platform::errors::InvalidArgument( "The input type should be half or float")); return input_types[0]; } int Pool3DPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc, const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT { auto input_dims = input_desc[0].dims; int n = input_dims.d[0]; int c = input_dims.d[1]; int d = input_dims.d[2]; int h = input_dims.d[3]; int w = input_dims.d[4]; const float *input = static_cast(inputs[0]); float *output = static_cast(outputs[0]); std::vector input_shape, output_shape; for (int i = 0; i < input_dims.nbDims; i++) input_shape.push_back(input_dims.d[i]); output_shape = input_shape; std::vector ksize = ksize_; std::vector paddings = paddings_; if (is_global_) { ksize[0] = d; ksize[1] = h; ksize[2] = w; paddings[0] = 0; paddings[1] = 0; paddings[2] = 0; output_shape[2] = 1; output_shape[3] = 1; output_shape[4] = 1; } else { auto data_dim = CalcOutputSize({d, h, w}, ceil_mode_, adaptive_, ksize_, strides_, paddings_); output_shape[2] = data_dim[0]; output_shape[3] = data_dim[1]; output_shape[4] = data_dim[2]; } if (pool3d_type_ == "max") { paddle::operators::math::MaxPool pool_process; paddle::operators::math::Pool3dDirectCUDAFunctor< paddle::operators::math::MaxPool, float> pool3d_forward; pool3d_forward(input, input_shape, output_shape, ksize, strides_, paddings, true, adaptive_, output, stream, pool_process); } else if (pool3d_type_ == "avg") { paddle::operators::math::AvgPool pool_process; paddle::operators::math::Pool3dDirectCUDAFunctor< paddle::operators::math::AvgPool, float> pool3d_forward; pool3d_forward(input, input_shape, output_shape, ksize, strides_, paddings, true, adaptive_, output, stream, pool_process); } return cudaGetLastError() != cudaSuccess; } } // namespace plugin } // namespace tensorrt } // namespace inference } // namespace paddle