未验证 提交 e6095bc2 编写于 作者: S Shang Zhizhou 提交者: GitHub

fix split trt plugin initialize (#30875)

* fix split trt plugin initialize

* update
上级 6e3856d3
...@@ -62,6 +62,16 @@ nvinfer1::Dims SplitPlugin::getOutputDimensions( ...@@ -62,6 +62,16 @@ nvinfer1::Dims SplitPlugin::getOutputDimensions(
return output_dims; return output_dims;
} }
void SplitPlugin::shareData(const SplitPlugin* another) {
outer_rows_ = another->outer_rows_;
inner_cols_ = another->inner_cols_;
same_shape_ = another->same_shape_;
axis_shape_ = another->axis_shape_;
d_segment_offsets_ = another->d_segment_offsets_;
segment_offsets_ = another->segment_offsets_;
d_output_ptrs_.resize(another->d_output_ptrs_.size(), nullptr);
}
int SplitPlugin::initialize() { int SplitPlugin::initialize() {
PADDLE_ENFORCE_LE(axis_, nvinfer1::Dims::MAX_DIMS, PADDLE_ENFORCE_LE(axis_, nvinfer1::Dims::MAX_DIMS,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -93,6 +103,9 @@ int SplitPlugin::initialize() { ...@@ -93,6 +103,9 @@ int SplitPlugin::initialize() {
return 0; return 0;
} }
// nothing to release according to initialize
void SplitPlugin::terminate() {}
// The following part of the code refers to onnx-tensorrt // The following part of the code refers to onnx-tensorrt
// https://github.com/onnx/onnx-tensorrt/blob/master/Split.cu // https://github.com/onnx/onnx-tensorrt/blob/master/Split.cu
template <typename T> template <typename T>
......
...@@ -40,7 +40,9 @@ class SplitPlugin : public PluginTensorRT { ...@@ -40,7 +40,9 @@ class SplitPlugin : public PluginTensorRT {
} }
SplitPlugin* clone() const override { SplitPlugin* clone() const override {
return new SplitPlugin(axis_, output_length_, with_fp16_); auto* ptr = new SplitPlugin(axis_, output_length_, with_fp16_);
ptr->shareData(this);
return ptr;
} }
const char* getPluginType() const override { return "split_plugin"; } const char* getPluginType() const override { return "split_plugin"; }
...@@ -50,6 +52,7 @@ class SplitPlugin : public PluginTensorRT { ...@@ -50,6 +52,7 @@ class SplitPlugin : public PluginTensorRT {
int num_inputs) override; int num_inputs) override;
int initialize() override; int initialize() override;
void terminate() override;
int enqueue(int batchSize, const void* const* inputs, void** outputs, int enqueue(int batchSize, const void* const* inputs, void** outputs,
void* workspace, cudaStream_t stream) override; void* workspace, cudaStream_t stream) override;
...@@ -75,6 +78,9 @@ class SplitPlugin : public PluginTensorRT { ...@@ -75,6 +78,9 @@ class SplitPlugin : public PluginTensorRT {
std::vector<int> segment_offsets_; std::vector<int> segment_offsets_;
thrust::device_vector<int> d_segment_offsets_; thrust::device_vector<int> d_segment_offsets_;
thrust::device_vector<float*> d_output_ptrs_; thrust::device_vector<float*> d_output_ptrs_;
private:
void shareData(const SplitPlugin* another);
}; };
#if IS_TRT_VERSION_GE(6000) #if IS_TRT_VERSION_GE(6000)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册