From b7a86e92a843ff53d7be9bad0737a7169d77977d Mon Sep 17 00:00:00 2001 From: Zhaolong Xing Date: Wed, 19 Aug 2020 13:04:18 +0800 Subject: [PATCH] fix dy shape bug in trt7.1 (#26273) test=develop --- paddle/fluid/inference/tensorrt/engine.h | 7 ++++++- .../tensorrt/plugin/emb_eltwise_layernorm_plugin.cu | 12 +++++++++++- .../tensorrt/plugin/emb_eltwise_layernorm_plugin.h | 7 ++++++- .../inference/tensorrt/plugin/prelu_op_plugin.cu | 6 ++++++ .../inference/tensorrt/plugin/prelu_op_plugin.h | 5 ++++- .../tensorrt/plugin/skip_layernorm_op_plugin.h | 5 ++++- 6 files changed, 37 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index fdd71b0d88..1a3413657c 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -83,7 +83,12 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector& shape, std::string input, } else if (shape.size() == 3UL) { return nvinfer1::Dims3(shape[0], shape[1], shape[2]); } - return nvinfer1::Dims4(shape[0], shape[1], 1, 1); + nvinfer1::Dims dims; + dims.nbDims = shape.size(); + for (size_t i = 0; i < shape.size(); i++) { + dims.d[i] = shape[i]; + } + return dims; } } } // NOLINT diff --git a/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu index e7f9381e97..5e43be90de 100644 --- a/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu @@ -76,6 +76,16 @@ nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic::getOutputDimensions( return ret; } +template +void EmbEltwiseLayernormPluginDynamic::terminate() { + for (auto ptr : embs_gpu_) { + if (ptr) cudaFree(ptr); + } + + if (bias_gpu_) cudaFree(bias_gpu_); + if (scale_gpu_) cudaFree(scale_gpu_); +} + template bool EmbEltwiseLayernormPluginDynamic::supportsFormatCombination( int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs, @@ -153,7 +163,7 @@ int EmbEltwiseLayernormPluginDynamic::enqueue( int64_t *emb_ptr_gpu_d = emb_ptr_tensor.mutable_data(platform::CUDAPlace(device_id)); - std::vector in_ptr, emb_ptr; + std::vector in_ptr, emb_ptr; for (int i = 0; i < input_num; i++) { in_ptr.push_back(reinterpret_cast(inputs[i])); emb_ptr.push_back(reinterpret_cast(embs_gpu_[i])); diff --git a/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h b/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h index 8ac611cd7c..5babd87db0 100644 --- a/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h @@ -81,9 +81,13 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT { } nvinfer1::IPluginV2DynamicExt* clone() const override { - return new EmbEltwiseLayernormPluginDynamic( + auto ptr = new EmbEltwiseLayernormPluginDynamic( embs_, bias_, scale_, emb_sizes_, bias_size_, scale_size_, hidden_size_, eps_); + ptr->embs_gpu_ = embs_gpu_; + ptr->bias_gpu_ = bias_gpu_; + ptr->scale_gpu_ = scale_gpu_; + return ptr; } const char* getPluginType() const override { @@ -111,6 +115,7 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT { return sum_num; } + void terminate() override; void serialize(void* buffer) const override { // SerializeValue(&buffer, with_fp16_); SerializeValue(&buffer, emb_sizes_); diff --git a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu index f1e11b6fba..860f1039d5 100644 --- a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu @@ -80,6 +80,12 @@ int PReluPlugin::enqueue(int batch_size, const void *const *inputs, #if IS_TRT_VERSION_GE(6000) +void PReluPluginDynamic::terminate() { + if (p_gpu_weight_) { + cudaFree(p_gpu_weight_); + } +} + int PReluPluginDynamic::initialize() { cudaMalloc(&p_gpu_weight_, sizeof(float) * weight_.size()); cudaMemcpy(p_gpu_weight_, weight_.data(), weight_.size() * sizeof(float), diff --git a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h index 4756ca2e02..3126366c5f 100644 --- a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h @@ -102,12 +102,15 @@ class PReluPluginDynamic : public DynamicPluginTensorRT { } ~PReluPluginDynamic() { cudaFree(p_gpu_weight_); } nvinfer1::IPluginV2DynamicExt* clone() const override { - return new PReluPluginDynamic(weight_.data(), weight_.size(), mode_); + auto ptr = new PReluPluginDynamic(weight_.data(), weight_.size(), mode_); + ptr->p_gpu_weight_ = p_gpu_weight_; + return ptr; } const char* getPluginType() const override { return "prelu_plugin"; } int getNbOutputs() const override { return 1; } int initialize() override; + void terminate() override; size_t getSerializationSize() const override; void serialize(void* buffer) const override; diff --git a/paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h index 8fe1edc4bf..24cd8e0368 100644 --- a/paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h @@ -51,8 +51,11 @@ class SkipLayerNormPluginDynamic : public DynamicPluginTensorRT { } nvinfer1::IPluginV2DynamicExt* clone() const override { - return new SkipLayerNormPluginDynamic( + auto ptr = new SkipLayerNormPluginDynamic( bias_.data(), scale_.data(), bias_size_, scale_size_, eps_, ban_fp16_); + ptr->bias_gpu_ = bias_gpu_; + ptr->scale_gpu_ = bias_gpu_; + return ptr; } const char* getPluginType() const override { return "skip_layernorm_plugin"; } -- GitLab