未验证 提交 e6095bc2 编写于 作者: S Shang Zhizhou 提交者: GitHub

fix split trt plugin initialize (#30875)

* fix split trt plugin initialize

* update
上级 6e3856d3
......@@ -62,6 +62,16 @@ nvinfer1::Dims SplitPlugin::getOutputDimensions(
return output_dims;
}
void SplitPlugin::shareData(const SplitPlugin* another) {
outer_rows_ = another->outer_rows_;
inner_cols_ = another->inner_cols_;
same_shape_ = another->same_shape_;
axis_shape_ = another->axis_shape_;
d_segment_offsets_ = another->d_segment_offsets_;
segment_offsets_ = another->segment_offsets_;
d_output_ptrs_.resize(another->d_output_ptrs_.size(), nullptr);
}
int SplitPlugin::initialize() {
PADDLE_ENFORCE_LE(axis_, nvinfer1::Dims::MAX_DIMS,
platform::errors::InvalidArgument(
......@@ -93,6 +103,9 @@ int SplitPlugin::initialize() {
return 0;
}
// nothing to release according to initialize
void SplitPlugin::terminate() {}
// The following part of the code refers to onnx-tensorrt
// https://github.com/onnx/onnx-tensorrt/blob/master/Split.cu
template <typename T>
......
......@@ -40,7 +40,9 @@ class SplitPlugin : public PluginTensorRT {
}
SplitPlugin* clone() const override {
return new SplitPlugin(axis_, output_length_, with_fp16_);
auto* ptr = new SplitPlugin(axis_, output_length_, with_fp16_);
ptr->shareData(this);
return ptr;
}
const char* getPluginType() const override { return "split_plugin"; }
......@@ -50,6 +52,7 @@ class SplitPlugin : public PluginTensorRT {
int num_inputs) override;
int initialize() override;
void terminate() override;
int enqueue(int batchSize, const void* const* inputs, void** outputs,
void* workspace, cudaStream_t stream) override;
......@@ -75,6 +78,9 @@ class SplitPlugin : public PluginTensorRT {
std::vector<int> segment_offsets_;
thrust::device_vector<int> d_segment_offsets_;
thrust::device_vector<float*> d_output_ptrs_;
private:
void shareData(const SplitPlugin* another);
};
#if IS_TRT_VERSION_GE(6000)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册