From 758fccfef134b5b1c20955ee75d002191723de46 Mon Sep 17 00:00:00 2001
From: Zhang Jun
Date: Thu, 1 Dec 2022 18:57:12 +0800
Subject: [PATCH] [inference][trt] dynamic shape support for Instance norm
 (#47998)

* instance norm support dynamic shape

* update unittest
---
 .../tensorrt/convert/instance_norm_op.cc      |  14 +-
 paddle/fluid/inference/tensorrt/op_teller.cc  |   4 -
 .../plugin/instance_norm_op_plugin.cu         | 109 ++++++++++++++
 .../tensorrt/plugin/instance_norm_op_plugin.h | 134 +++++++++++++++++-
 .../test_trt_convert_instance_norm.py         |   8 +-
 5 files changed, 258 insertions(+), 11 deletions(-)

diff --git a/paddle/fluid/inference/tensorrt/convert/instance_norm_op.cc b/paddle/fluid/inference/tensorrt/convert/instance_norm_op.cc
index aef91a9a69e..6a6e67328bb 100644
--- a/paddle/fluid/inference/tensorrt/convert/instance_norm_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/instance_norm_op.cc
@@ -74,10 +74,16 @@ class InstanceNormOpConverter : public OpConverter {
       bias_v.push_back(bias_d[i]);
     }
 
-    plugin::InstanceNormPlugin* plugin =
-        new plugin::InstanceNormPlugin(eps, scale_v, bias_v);
-    plugin->getPluginType();
-    auto* layer = engine_->AddPlugin(&input, 1, plugin);
+    nvinfer1::IPluginV2* plugin = nullptr;
+    if (engine_->with_dynamic_shape()) {
+      plugin = new plugin::InstanceNormPluginDynamic(eps, scale_v, bias_v);
+    } else {
+      plugin = new plugin::InstanceNormPlugin(eps, scale_v, bias_v);
+    }
+
+    std::vector<nvinfer1::ITensor*> instance_norm_inputs{input};
+    auto* layer = engine_->network()->addPluginV2(
+        instance_norm_inputs.data(), instance_norm_inputs.size(), *plugin);
 
     auto output_name = op_desc.Output("Y")[0];
     RreplenishLayerAndOutput(layer, "instance_norm", {output_name}, test_mode);
diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc
index ce14e506dd6..17fb2f0aa6d 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -1501,10 +1501,6 @@ struct SimpleOpTypeSetTeller : public Teller {
     }
 
     if (op_type == "instance_norm") {
-      if (with_dynamic_shape) {
-        VLOG(3) << "trt instance_norm op does not support dynamic shape ";
-        return false;
-      }
       if (desc.Input("X").size() != 1) {
         VLOG(3) << "input of instance_norm op converter should be 1, got "
                 << desc.Input("X").size();
diff --git a/paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.cu
index 6dd31dff016..82e24bea09a 100644
--- a/paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.cu
@@ -131,6 +131,115 @@ int InstanceNormPlugin::enqueue(int batch_size,
   return cudaGetLastError() != cudaSuccess;
 }
 
+int InstanceNormPluginDynamic::initialize() TRT_NOEXCEPT { return 0; }
+
+nvinfer1::DimsExprs InstanceNormPluginDynamic::getOutputDimensions(
+    int index,
+    const nvinfer1::DimsExprs *inputs,
+    int nbInputs,
+    nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
+  assert(nbInputs == 1);
+  assert(index < this->getNbOutputs());
+  nvinfer1::DimsExprs output(inputs[0]);
+  return output;
+}
+
+bool InstanceNormPluginDynamic::supportsFormatCombination(
+    int pos,
+    const nvinfer1::PluginTensorDesc *inOut,
+    int nbInputs,
+    int nbOutputs) TRT_NOEXCEPT {
+  assert(inOut && pos < (nbInputs + nbOutputs));
+  assert(pos == 0 || pos == 1);
+  return ((inOut[pos].type == nvinfer1::DataType::kFLOAT ||
+           inOut[pos].type == nvinfer1::DataType::kHALF) &&
+          (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR) &&
+          inOut[pos].type == inOut[0].type);
+}
+
+int InstanceNormPluginDynamic::enqueue(
+    const nvinfer1::PluginTensorDesc *inputDesc,
+    const nvinfer1::PluginTensorDesc *outputDesc,
+    const void *const *inputs,
+    void *const *outputs,
+    void *workspace,
+    cudaStream_t stream) TRT_NOEXCEPT {
+  nvinfer1::Dims input_dims = inputDesc[0].dims;
+  int n = input_dims.d[0];
+  int c = input_dims.d[1];
+  int h = input_dims.d[2];
+  int w = input_dims.d[3];
+
+  scale_t.Resize(phi::make_ddim({n, c}));
+  bias_t.Resize(phi::make_ddim({n, c}));
+  int device_id;
+  cudaGetDevice(&device_id);
+  float *scale_d = scale_t.mutable_data<float>(platform::CUDAPlace(device_id));
+  float *bias_d = bias_t.mutable_data<float>(platform::CUDAPlace(device_id));
+
+  for (int i = 0; i < n; i++) {
+    cudaMemcpyAsync(scale_d + i * c,
+                    scale_.data(),
+                    sizeof(float) * c,
+                    cudaMemcpyHostToDevice,
+                    stream);
+    cudaMemcpyAsync(bias_d + i * c,
+                    bias_.data(),
+                    sizeof(float) * c,
+                    cudaMemcpyHostToDevice,
+                    stream);
+  }
+  platform::dynload::cudnnSetTensor4dDescriptor(
+      b_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, n * c, 1, 1);
+
+  cudnnDataType_t cudnn_dtype;
+  auto data_type = inputDesc[0].type;
+  convert_trt2cudnn_dtype(data_type, &cudnn_dtype);
+  platform::dynload::cudnnSetTensor4dDescriptor(
+      x_desc_, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, n * c, h, w);
+  platform::dynload::cudnnSetTensor4dDescriptor(
+      y_desc_, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, n * c, h, w);
+  float alpha = 1;
+  float beta = 0;
+  platform::dynload::cudnnSetStream(handle_, stream);
+
+  void const *x_ptr = inputs[0];
+  void *y_ptr = outputs[0];
+  platform::dynload::cudnnBatchNormalizationForwardTraining(
+      handle_,
+      CUDNN_BATCHNORM_SPATIAL_PERSISTENT,
+      &alpha,
+      &beta,
+      x_desc_,
+      x_ptr,
+      y_desc_,
+      y_ptr,
+      b_desc_,
+      scale_d,
+      bias_d,
+      1.,
+      nullptr,
+      nullptr,
+      eps_,
+      nullptr,
+      nullptr);
+  return cudaGetLastError() != cudaSuccess;
+}
+
+nvinfer1::DataType InstanceNormPluginDynamic::getOutputDataType(
+    int index,
+    const nvinfer1::DataType *inputTypes,
+    int nbInputs) const TRT_NOEXCEPT {
+  assert(inputTypes && nbInputs > 0 && index == 0);
+  return inputTypes[0];
+}
+
+void InstanceNormPluginDynamic::configurePlugin(
+    const nvinfer1::DynamicPluginTensorDesc *in,
+    int nbInputs,
+    const nvinfer1::DynamicPluginTensorDesc *out,
+    int nbOutputs) TRT_NOEXCEPT {}
+
 }  // namespace plugin
 }  // namespace tensorrt
 }  // namespace inference
diff --git a/paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h
index 90a01d076f3..6a89139396c 100644
--- a/paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h
+++ b/paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h
@@ -99,7 +99,7 @@ class InstanceNormPlugin : public PluginTensorRT {
   }
 
   const char *getPluginType() const TRT_NOEXCEPT override {
-    return "instance_norm_plugin";
+    return "instance_norm";
   }
   int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
   nvinfer1::Dims getOutputDimensions(int index,
@@ -125,7 +125,7 @@ class InstanceNormPlugin : public PluginTensorRT {
 class InstanceNormPluginCreator : public TensorRTPluginCreator {
  public:
   const char *getPluginName() const TRT_NOEXCEPT override {
-    return "instance_norm_plugin";
+    return "instance_norm";
   }
 
   const char *getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
@@ -137,7 +137,137 @@ class InstanceNormPluginCreator : public TensorRTPluginCreator {
     return new InstanceNormPlugin(serial_data, serial_length);
   }
 };
+
+class InstanceNormPluginDynamic : public DynamicPluginTensorRT {
+ private:
+  float eps_;
+  std::vector<float> scale_;
+  std::vector<float> bias_;
+
+  phi::DenseTensor scale_t;
+  phi::DenseTensor bias_t;
+  cudnnHandle_t handle_;
+  cudnnTensorDescriptor_t x_desc_, y_desc_, b_desc_;
+
+ public:
+  size_t getSerializationSize() const TRT_NOEXCEPT override {
+    return SerializedSize(eps_) + SerializedSize(scale_) +
+           SerializedSize(bias_);
+  }
+
+  // TensorRT calls this method to serialize the plugin configuration.
+  // It should not be called by users.
+  void serialize(void *buffer) const TRT_NOEXCEPT override {
+    SerializeValue(&buffer, eps_);
+    SerializeValue(&buffer, scale_);
+    SerializeValue(&buffer, bias_);
+  }
+
+  explicit InstanceNormPluginDynamic(const float eps,
+                                     const std::vector<float> scale,
+                                     const std::vector<float> bias)
+      : eps_(eps), scale_(scale), bias_(bias) {
+    PADDLE_ENFORCE_EQ(scale.size(),
+                      bias.size(),
+                      platform::errors::InvalidArgument(
+                          "The instance_norm's scale and bias should have the "
+                          "same size. Got scale size = %d, but bias size = %d",
+                          scale.size(),
+                          bias.size()));
+    platform::dynload::cudnnCreate(&handle_);
+    platform::dynload::cudnnCreateTensorDescriptor(&x_desc_);
+    platform::dynload::cudnnCreateTensorDescriptor(&y_desc_);
+    platform::dynload::cudnnCreateTensorDescriptor(&b_desc_);
+  }
+
+  // Used by TensorRT for deserialization.
+  // It should not be called by users.
+  InstanceNormPluginDynamic(void const *serialData, size_t serialLength) {
+    DeserializeValue(&serialData, &serialLength, &eps_);
+    DeserializeValue(&serialData, &serialLength, &scale_);
+    DeserializeValue(&serialData, &serialLength, &bias_);
+
+    platform::dynload::cudnnCreate(&handle_);
+    platform::dynload::cudnnCreateTensorDescriptor(&x_desc_);
+    platform::dynload::cudnnCreateTensorDescriptor(&y_desc_);
+    platform::dynload::cudnnCreateTensorDescriptor(&b_desc_);
+  }
+
+  ~InstanceNormPluginDynamic() {
+    platform::dynload::cudnnDestroy(handle_);
+    platform::dynload::cudnnDestroyTensorDescriptor(x_desc_);
+    platform::dynload::cudnnDestroyTensorDescriptor(y_desc_);
+    platform::dynload::cudnnDestroyTensorDescriptor(b_desc_);
+  }
+
+  int initialize() TRT_NOEXCEPT override;
+
+  nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override {
+    return new InstanceNormPluginDynamic(eps_, scale_, bias_);
+  }
+
+  const char *getPluginType() const TRT_NOEXCEPT override {
+    return "instance_norm_dynamic";
+  }
+  int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
+  nvinfer1::DimsExprs getOutputDimensions(
+      int output_index,
+      const nvinfer1::DimsExprs *inputs,
+      int nb_inputs,
+      nvinfer1::IExprBuilder &expr_builder)  // NOLINT
+      TRT_NOEXCEPT override;
+
+  bool supportsFormatCombination(int pos,
+                                 const nvinfer1::PluginTensorDesc *inOut,
+                                 int nbInputs,
+                                 int nbOutputs) TRT_NOEXCEPT override;
+
+  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
+                       int nbInputs,
+                       const nvinfer1::DynamicPluginTensorDesc *out,
+                       int nbOutputs) TRT_NOEXCEPT override;
+
+  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
+                          int nbInputs,
+                          const nvinfer1::PluginTensorDesc *outputs,
+                          int nbOutputs) const TRT_NOEXCEPT override {
+    return 0;
+  }
+
+  int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
+              const nvinfer1::PluginTensorDesc *outputDesc,
+              const void *const *inputs,
+              void *const *outputs,
+              void *workspace,
+              cudaStream_t stream) TRT_NOEXCEPT override;
+
+  nvinfer1::DataType getOutputDataType(int index,
+                                       const nvinfer1::DataType *inputTypes,
+                                       int nbInputs) const
+      TRT_NOEXCEPT override;
+
+  void destroy() TRT_NOEXCEPT override { delete this; }
+};
+
+class InstanceNormPluginDynamicCreator : public TensorRTPluginCreator {
+ public:
+  const char *getPluginName() const TRT_NOEXCEPT override {
+    return "instance_norm_dynamic";
+  }
+
+  const char *getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
+
+  nvinfer1::IPluginV2 *deserializePlugin(const char *name,
+                                         const void *serial_data,
+                                         size_t serial_length)
+      TRT_NOEXCEPT override {
+    return new InstanceNormPluginDynamic(serial_data, serial_length);
+  }
+};
+
 REGISTER_TRT_PLUGIN_V2(InstanceNormPluginCreator);
+REGISTER_TRT_PLUGIN_V2(InstanceNormPluginDynamicCreator);
 
 }  // namespace plugin
 }  // namespace tensorrt
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_instance_norm.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_instance_norm.py
index a65588b8c5e..72b728d5cc3 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_instance_norm.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_instance_norm.py
@@ -50,7 +50,13 @@ class TrtConvertInstanceNormTest(TrtLayerAutoScanTest):
                 [batch, 16, 32, 64],
             ]:
                 self.in_dim = len(shape_input)
-                for epsilon in [0.0005, -1, 1]:
+                for epsilon in [
+                    0.0005,
+                    -1,
+                    1,
+                    0.000009999999747378752,
+                    0.00001,
+                ]:
                     dics = [{"epsilon": epsilon}]
                     ops_config = [
                         {
-- 
GitLab
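
Reviewer note on the enqueue path: InstanceNormPluginDynamic computes
instance normalization by folding the batch dimension into the channel axis
and running cuDNN batch normalization over a (1, N*C, H, W) view of the
input, with the per-channel scale and bias tiled N times on the device
(hence the cudaMemcpyAsync loop). A minimal standalone sketch of the same
trick, assuming only cuDNN; the function and buffer names are illustrative,
not code from this patch:

    #include <cudnn.h>

    // Instance norm over an NCHW tensor expressed as batch norm over a
    // (1, N*C, H, W) view: each (n, c) slice gets its own mean/variance.
    // "scale" and "bias" must already be tiled to length N*C on the device.
    void InstanceNormViaBatchNorm(cudnnHandle_t handle,
                                  const float *x, float *y,
                                  const float *scale, const float *bias,
                                  int n, int c, int h, int w, double eps) {
      cudnnTensorDescriptor_t data_desc, param_desc;
      cudnnCreateTensorDescriptor(&data_desc);
      cudnnCreateTensorDescriptor(&param_desc);
      cudnnSetTensor4dDescriptor(data_desc, CUDNN_TENSOR_NCHW,
                                 CUDNN_DATA_FLOAT, 1, n * c, h, w);
      cudnnSetTensor4dDescriptor(param_desc, CUDNN_TENSOR_NCHW,
                                 CUDNN_DATA_FLOAT, 1, n * c, 1, 1);
      float alpha = 1.0f, beta = 0.0f;
      // Training mode recomputes statistics from the current input; the
      // running-statistics outputs are unused, so nullptr is passed.
      cudnnBatchNormalizationForwardTraining(
          handle, CUDNN_BATCHNORM_SPATIAL, &alpha, &beta,
          data_desc, x, data_desc, y, param_desc, scale, bias,
          /*exponentialAverageFactor=*/1.0,
          /*resultRunningMean=*/nullptr, /*resultRunningVariance=*/nullptr,
          eps,
          /*resultSaveMean=*/nullptr, /*resultSaveInvVariance=*/nullptr);
      cudnnDestroyTensorDescriptor(data_desc);
      cudnnDestroyTensorDescriptor(param_desc);
    }

Because statistics are recomputed on every call, the "training" entry point
with null running-statistics outputs is the right cuDNN API here even though
this runs at inference time. The two epsilons added to the unit test look
like boundary probes for this path: 0.000009999999747378752 is float32(1e-5),
just below cuDNN's historical CUDNN_BN_MIN_EPSILON of 1e-5, and 0.00001 sits
exactly on it.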
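
Usage note: with the op_teller restriction removed, instance_norm is lowered
to InstanceNormPluginDynamic only when the predictor is built in
dynamic-shape mode. A sketch of that configuration, assuming Paddle's C++
inference API (paddle_infer::Config); the model directory and the input name
"x" are placeholders:

    #include <map>
    #include <string>
    #include <vector>
    #include "paddle_inference_api.h"

    int main() {
      paddle_infer::Config config;
      config.SetModel("./model_dir");  // placeholder model directory
      config.EnableUseGpu(/*memory_pool_init_size_mb=*/512, /*device_id=*/0);
      config.EnableTensorRtEngine(/*workspace_size=*/1 << 30,
                                  /*max_batch_size=*/4,
                                  /*min_subgraph_size=*/3,
                                  paddle_infer::PrecisionType::kFloat32,
                                  /*use_static=*/false,
                                  /*use_calib_mode=*/false);
      // Registering min/opt/max shape ranges switches TensorRT to
      // dynamic-shape mode, which is what makes
      // engine_->with_dynamic_shape() true in the converter and selects
      // InstanceNormPluginDynamic over the static plugin.
      std::map<std::string, std::vector<int>> min_shape{{"x", {1, 3, 32, 32}}};
      std::map<std::string, std::vector<int>> max_shape{{"x", {4, 3, 64, 64}}};
      std::map<std::string, std::vector<int>> opt_shape{{"x", {1, 3, 64, 64}}};
      config.SetTRTDynamicShapeInfo(min_shape, max_shape, opt_shape);
      auto predictor = paddle_infer::CreatePredictor(config);
      return predictor != nullptr ? 0 : 1;
    }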