From b779d2b8bb2dbe17987f7c490c487f3a430ea582 Mon Sep 17 00:00:00 2001 From: Wilber Date: Tue, 31 May 2022 11:27:12 +0800 Subject: [PATCH] fix slice plugin (#43110) --- .../tensorrt/plugin/slice_op_plugin.cu | 46 ++++++------------- .../tensorrt/plugin/slice_op_plugin.h | 6 +-- 2 files changed, 16 insertions(+), 36 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu index 4e6b82d2dc1..0a6d24f9072 100644 --- a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu @@ -56,8 +56,6 @@ SlicePlugin::SlicePlugin(std::vector starts, std::vector ends, std::vector axes, bool with_fp16) : starts_(starts), ends_(ends), axes_(axes) { with_fp16_ = with_fp16; - cudaEventCreate(©_event_); - cudaStreamCreate(©_stream_); } SlicePlugin::SlicePlugin(void const *serial_data, size_t serial_length) { @@ -66,15 +64,10 @@ SlicePlugin::SlicePlugin(void const *serial_data, size_t serial_length) { DeserializeValue(&serial_data, &serial_length, &ends_); DeserializeValue(&serial_data, &serial_length, &axes_); DeserializeValue(&serial_data, &serial_length, &with_fp16_); - cudaEventCreate(©_event_); - cudaStreamCreate(©_stream_); + DeserializeValue(&serial_data, &serial_length, &offset_info_); } -SlicePlugin::~SlicePlugin() { - cudaStreamDestroy(copy_stream_); - cudaEventDestroy(copy_event_); - cudaFree(offset_temp_data_); -} +SlicePlugin::~SlicePlugin() { cudaFree(offset_temp_data_); } SlicePlugin *SlicePlugin::clone() const TRT_NOEXCEPT { return new SlicePlugin(starts_, ends_, axes_, with_fp16_); @@ -159,11 +152,7 @@ int SlicePlugin::enqueue(int batch_size, const void *const *inputs, } cudaMemcpyAsync(offset_temp_data_, offset_info.data(), - sizeof(int) * 3 * num_dims, cudaMemcpyHostToDevice, - copy_stream_); - - cudaEventRecord(copy_event_, copy_stream_); - cudaStreamWaitEvent(stream, copy_event_, 0); + sizeof(int) * 3 * num_dims, cudaMemcpyHostToDevice, stream); int threads = 256; int blocks = (out_num + threads - 1) / threads; @@ -190,7 +179,7 @@ int SlicePlugin::enqueue(int batch_size, const void *const *inputs, size_t SlicePlugin::getSerializationSize() const TRT_NOEXCEPT { return getBaseSerializationSize() + SerializedSize(starts_) + SerializedSize(ends_) + SerializedSize(axes_) + - SerializedSize(with_fp16_); + SerializedSize(with_fp16_) + SerializedSize(offset_info_); } void SlicePlugin::serialize(void *buffer) const TRT_NOEXCEPT { @@ -199,6 +188,7 @@ void SlicePlugin::serialize(void *buffer) const TRT_NOEXCEPT { SerializeValue(&buffer, ends_); SerializeValue(&buffer, axes_); SerializeValue(&buffer, with_fp16_); + SerializeValue(&buffer, offset_info_); } // Dynamic Plugin below. @@ -209,8 +199,6 @@ SlicePluginDynamic::SlicePluginDynamic(std::vector starts, bool with_fp16) : starts_(starts), ends_(ends), axes_(axes), decrease_axis_(decrease_axis) { with_fp16_ = with_fp16; - cudaEventCreate(©_event_); - cudaStreamCreate(©_stream_); } SlicePluginDynamic::SlicePluginDynamic(void const *serialData, @@ -220,13 +208,10 @@ SlicePluginDynamic::SlicePluginDynamic(void const *serialData, DeserializeValue(&serialData, &serialLength, &axes_); DeserializeValue(&serialData, &serialLength, &decrease_axis_); DeserializeValue(&serialData, &serialLength, &with_fp16_); - cudaEventCreate(©_event_); - cudaStreamCreate(©_stream_); + DeserializeValue(&serialData, &serialLength, &offset_info_); } void SlicePluginDynamic::destroy() TRT_NOEXCEPT { - cudaStreamDestroy(copy_stream_); - cudaEventDestroy(copy_event_); cudaFree(offset_temp_data_); delete this; } @@ -236,7 +221,7 @@ int SlicePluginDynamic::initialize() TRT_NOEXCEPT { return 0; } size_t SlicePluginDynamic::getSerializationSize() const TRT_NOEXCEPT { size_t size = SerializedSize(starts_) + SerializedSize(ends_) + SerializedSize(axes_) + SerializedSize(decrease_axis_) + - SerializedSize(with_fp16_); + SerializedSize(with_fp16_) + SerializedSize(offset_info_); return size; } @@ -247,6 +232,7 @@ void SlicePluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT { SerializeValue(&buffer, axes_); SerializeValue(&buffer, decrease_axis_); SerializeValue(&buffer, with_fp16_); + SerializeValue(&buffer, offset_info_); } nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions( @@ -361,23 +347,19 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc, offsets[axes_[i]] = starts_[i]; } - std::vector offset_info; + offset_info_.resize(num_dims * 3); for (size_t i = 0; i < num_dims; ++i) { - offset_info.push_back(offsets[i]); - offset_info.push_back(extends[i]); - offset_info.push_back(seg_offsets[i]); + offset_info_[i * 3 + 0] = offsets[i]; + offset_info_[i * 3 + 1] = extends[i]; + offset_info_[i * 3 + 2] = seg_offsets[i]; } if (offset_temp_data_ == nullptr) { cudaMalloc(&offset_temp_data_, 3 * num_dims * sizeof(int)); } - cudaMemcpyAsync(offset_temp_data_, offset_info.data(), - sizeof(int) * 3 * num_dims, cudaMemcpyHostToDevice, - copy_stream_); - - cudaEventRecord(copy_event_, copy_stream_); - cudaStreamWaitEvent(stream, copy_event_, 0); + cudaMemcpyAsync(offset_temp_data_, offset_info_.data(), + sizeof(int) * 3 * num_dims, cudaMemcpyHostToDevice, stream); int threads = 256; int blocks = (out_num + threads - 1) / threads; diff --git a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h index 4c07f0be368..6b50a52df1f 100644 --- a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h @@ -64,8 +64,7 @@ class SlicePlugin : public PluginTensorRT { std::vector ends_; std::vector axes_; int* offset_temp_data_{nullptr}; - cudaEvent_t copy_event_; - cudaStream_t copy_stream_; + std::vector offset_info_; }; class SlicePluginCreator : public TensorRTPluginCreator { @@ -144,8 +143,7 @@ class SlicePluginDynamic : public DynamicPluginTensorRT { std::vector axes_; int decrease_axis_; int* offset_temp_data_{nullptr}; - cudaEvent_t copy_event_; - cudaStream_t copy_stream_; + std::vector offset_info_; }; class SlicePluginDynamicCreator : public TensorRTPluginCreator { -- GitLab