From e6095bc2cea4bc729851e3324b649803c34b711b Mon Sep 17 00:00:00 2001 From: Shang Zhizhou Date: Thu, 4 Feb 2021 15:36:37 +0800 Subject: [PATCH] fix split trt plugin initialize (#30875) * fix split trt plugin initialize * update --- .../inference/tensorrt/plugin/split_op_plugin.cu | 13 +++++++++++++ .../inference/tensorrt/plugin/split_op_plugin.h | 8 +++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu index 2f4f731d887..256aa28206a 100644 --- a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu @@ -62,6 +62,16 @@ nvinfer1::Dims SplitPlugin::getOutputDimensions( return output_dims; } +void SplitPlugin::shareData(const SplitPlugin* another) { + outer_rows_ = another->outer_rows_; + inner_cols_ = another->inner_cols_; + same_shape_ = another->same_shape_; + axis_shape_ = another->axis_shape_; + d_segment_offsets_ = another->d_segment_offsets_; + segment_offsets_ = another->segment_offsets_; + d_output_ptrs_.resize(another->d_output_ptrs_.size(), nullptr); +} + int SplitPlugin::initialize() { PADDLE_ENFORCE_LE(axis_, nvinfer1::Dims::MAX_DIMS, platform::errors::InvalidArgument( @@ -93,6 +103,9 @@ int SplitPlugin::initialize() { return 0; } +// nothing to release according to initialize +void SplitPlugin::terminate() {} + // The following part of the code refers to onnx-tensorrt // https://github.com/onnx/onnx-tensorrt/blob/master/Split.cu template diff --git a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h index e3057f2bd18..5c47ec3a990 100644 --- a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h @@ -40,7 +40,9 @@ class SplitPlugin : public PluginTensorRT { } SplitPlugin* clone() const override { - return new SplitPlugin(axis_, output_length_, with_fp16_); + auto* ptr = new SplitPlugin(axis_, output_length_, with_fp16_); + ptr->shareData(this); + return ptr; } const char* getPluginType() const override { return "split_plugin"; } @@ -50,6 +52,7 @@ class SplitPlugin : public PluginTensorRT { int num_inputs) override; int initialize() override; + void terminate() override; int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override; @@ -75,6 +78,9 @@ class SplitPlugin : public PluginTensorRT { std::vector segment_offsets_; thrust::device_vector d_segment_offsets_; thrust::device_vector d_output_ptrs_; + + private: + void shareData(const SplitPlugin* another); }; #if IS_TRT_VERSION_GE(6000) -- GitLab