未验证 提交 b779d2b8 编写于 作者: W Wilber 提交者: GitHub

fix slice plugin (#43110)

上级 12d8a567
...@@ -56,8 +56,6 @@ SlicePlugin::SlicePlugin(std::vector<int> starts, std::vector<int> ends, ...@@ -56,8 +56,6 @@ SlicePlugin::SlicePlugin(std::vector<int> starts, std::vector<int> ends,
std::vector<int> axes, bool with_fp16) std::vector<int> axes, bool with_fp16)
: starts_(starts), ends_(ends), axes_(axes) { : starts_(starts), ends_(ends), axes_(axes) {
with_fp16_ = with_fp16; with_fp16_ = with_fp16;
cudaEventCreate(&copy_event_);
cudaStreamCreate(&copy_stream_);
} }
SlicePlugin::SlicePlugin(void const *serial_data, size_t serial_length) { SlicePlugin::SlicePlugin(void const *serial_data, size_t serial_length) {
...@@ -66,15 +64,10 @@ SlicePlugin::SlicePlugin(void const *serial_data, size_t serial_length) { ...@@ -66,15 +64,10 @@ SlicePlugin::SlicePlugin(void const *serial_data, size_t serial_length) {
DeserializeValue(&serial_data, &serial_length, &ends_); DeserializeValue(&serial_data, &serial_length, &ends_);
DeserializeValue(&serial_data, &serial_length, &axes_); DeserializeValue(&serial_data, &serial_length, &axes_);
DeserializeValue(&serial_data, &serial_length, &with_fp16_); DeserializeValue(&serial_data, &serial_length, &with_fp16_);
cudaEventCreate(&copy_event_); DeserializeValue(&serial_data, &serial_length, &offset_info_);
cudaStreamCreate(&copy_stream_);
} }
SlicePlugin::~SlicePlugin() { SlicePlugin::~SlicePlugin() { cudaFree(offset_temp_data_); }
cudaStreamDestroy(copy_stream_);
cudaEventDestroy(copy_event_);
cudaFree(offset_temp_data_);
}
SlicePlugin *SlicePlugin::clone() const TRT_NOEXCEPT { SlicePlugin *SlicePlugin::clone() const TRT_NOEXCEPT {
return new SlicePlugin(starts_, ends_, axes_, with_fp16_); return new SlicePlugin(starts_, ends_, axes_, with_fp16_);
...@@ -159,11 +152,7 @@ int SlicePlugin::enqueue(int batch_size, const void *const *inputs, ...@@ -159,11 +152,7 @@ int SlicePlugin::enqueue(int batch_size, const void *const *inputs,
} }
cudaMemcpyAsync(offset_temp_data_, offset_info.data(), cudaMemcpyAsync(offset_temp_data_, offset_info.data(),
sizeof(int) * 3 * num_dims, cudaMemcpyHostToDevice, sizeof(int) * 3 * num_dims, cudaMemcpyHostToDevice, stream);
copy_stream_);
cudaEventRecord(copy_event_, copy_stream_);
cudaStreamWaitEvent(stream, copy_event_, 0);
int threads = 256; int threads = 256;
int blocks = (out_num + threads - 1) / threads; int blocks = (out_num + threads - 1) / threads;
...@@ -190,7 +179,7 @@ int SlicePlugin::enqueue(int batch_size, const void *const *inputs, ...@@ -190,7 +179,7 @@ int SlicePlugin::enqueue(int batch_size, const void *const *inputs,
size_t SlicePlugin::getSerializationSize() const TRT_NOEXCEPT { size_t SlicePlugin::getSerializationSize() const TRT_NOEXCEPT {
return getBaseSerializationSize() + SerializedSize(starts_) + return getBaseSerializationSize() + SerializedSize(starts_) +
SerializedSize(ends_) + SerializedSize(axes_) + SerializedSize(ends_) + SerializedSize(axes_) +
SerializedSize(with_fp16_); SerializedSize(with_fp16_) + SerializedSize(offset_info_);
} }
void SlicePlugin::serialize(void *buffer) const TRT_NOEXCEPT { void SlicePlugin::serialize(void *buffer) const TRT_NOEXCEPT {
...@@ -199,6 +188,7 @@ void SlicePlugin::serialize(void *buffer) const TRT_NOEXCEPT { ...@@ -199,6 +188,7 @@ void SlicePlugin::serialize(void *buffer) const TRT_NOEXCEPT {
SerializeValue(&buffer, ends_); SerializeValue(&buffer, ends_);
SerializeValue(&buffer, axes_); SerializeValue(&buffer, axes_);
SerializeValue(&buffer, with_fp16_); SerializeValue(&buffer, with_fp16_);
SerializeValue(&buffer, offset_info_);
} }
// Dynamic Plugin below. // Dynamic Plugin below.
...@@ -209,8 +199,6 @@ SlicePluginDynamic::SlicePluginDynamic(std::vector<int> starts, ...@@ -209,8 +199,6 @@ SlicePluginDynamic::SlicePluginDynamic(std::vector<int> starts,
bool with_fp16) bool with_fp16)
: starts_(starts), ends_(ends), axes_(axes), decrease_axis_(decrease_axis) { : starts_(starts), ends_(ends), axes_(axes), decrease_axis_(decrease_axis) {
with_fp16_ = with_fp16; with_fp16_ = with_fp16;
cudaEventCreate(&copy_event_);
cudaStreamCreate(&copy_stream_);
} }
SlicePluginDynamic::SlicePluginDynamic(void const *serialData, SlicePluginDynamic::SlicePluginDynamic(void const *serialData,
...@@ -220,13 +208,10 @@ SlicePluginDynamic::SlicePluginDynamic(void const *serialData, ...@@ -220,13 +208,10 @@ SlicePluginDynamic::SlicePluginDynamic(void const *serialData,
DeserializeValue(&serialData, &serialLength, &axes_); DeserializeValue(&serialData, &serialLength, &axes_);
DeserializeValue(&serialData, &serialLength, &decrease_axis_); DeserializeValue(&serialData, &serialLength, &decrease_axis_);
DeserializeValue(&serialData, &serialLength, &with_fp16_); DeserializeValue(&serialData, &serialLength, &with_fp16_);
cudaEventCreate(&copy_event_); DeserializeValue(&serialData, &serialLength, &offset_info_);
cudaStreamCreate(&copy_stream_);
} }
void SlicePluginDynamic::destroy() TRT_NOEXCEPT { void SlicePluginDynamic::destroy() TRT_NOEXCEPT {
cudaStreamDestroy(copy_stream_);
cudaEventDestroy(copy_event_);
cudaFree(offset_temp_data_); cudaFree(offset_temp_data_);
delete this; delete this;
} }
...@@ -236,7 +221,7 @@ int SlicePluginDynamic::initialize() TRT_NOEXCEPT { return 0; } ...@@ -236,7 +221,7 @@ int SlicePluginDynamic::initialize() TRT_NOEXCEPT { return 0; }
size_t SlicePluginDynamic::getSerializationSize() const TRT_NOEXCEPT { size_t SlicePluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
size_t size = SerializedSize(starts_) + SerializedSize(ends_) + size_t size = SerializedSize(starts_) + SerializedSize(ends_) +
SerializedSize(axes_) + SerializedSize(decrease_axis_) + SerializedSize(axes_) + SerializedSize(decrease_axis_) +
SerializedSize(with_fp16_); SerializedSize(with_fp16_) + SerializedSize(offset_info_);
return size; return size;
} }
...@@ -247,6 +232,7 @@ void SlicePluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT { ...@@ -247,6 +232,7 @@ void SlicePluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {
SerializeValue(&buffer, axes_); SerializeValue(&buffer, axes_);
SerializeValue(&buffer, decrease_axis_); SerializeValue(&buffer, decrease_axis_);
SerializeValue(&buffer, with_fp16_); SerializeValue(&buffer, with_fp16_);
SerializeValue(&buffer, offset_info_);
} }
nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions( nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions(
...@@ -361,23 +347,19 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc, ...@@ -361,23 +347,19 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
offsets[axes_[i]] = starts_[i]; offsets[axes_[i]] = starts_[i];
} }
std::vector<int> offset_info; offset_info_.resize(num_dims * 3);
for (size_t i = 0; i < num_dims; ++i) { for (size_t i = 0; i < num_dims; ++i) {
offset_info.push_back(offsets[i]); offset_info_[i * 3 + 0] = offsets[i];
offset_info.push_back(extends[i]); offset_info_[i * 3 + 1] = extends[i];
offset_info.push_back(seg_offsets[i]); offset_info_[i * 3 + 2] = seg_offsets[i];
} }
if (offset_temp_data_ == nullptr) { if (offset_temp_data_ == nullptr) {
cudaMalloc(&offset_temp_data_, 3 * num_dims * sizeof(int)); cudaMalloc(&offset_temp_data_, 3 * num_dims * sizeof(int));
} }
cudaMemcpyAsync(offset_temp_data_, offset_info.data(), cudaMemcpyAsync(offset_temp_data_, offset_info_.data(),
sizeof(int) * 3 * num_dims, cudaMemcpyHostToDevice, sizeof(int) * 3 * num_dims, cudaMemcpyHostToDevice, stream);
copy_stream_);
cudaEventRecord(copy_event_, copy_stream_);
cudaStreamWaitEvent(stream, copy_event_, 0);
int threads = 256; int threads = 256;
int blocks = (out_num + threads - 1) / threads; int blocks = (out_num + threads - 1) / threads;
......
...@@ -64,8 +64,7 @@ class SlicePlugin : public PluginTensorRT { ...@@ -64,8 +64,7 @@ class SlicePlugin : public PluginTensorRT {
std::vector<int> ends_; std::vector<int> ends_;
std::vector<int> axes_; std::vector<int> axes_;
int* offset_temp_data_{nullptr}; int* offset_temp_data_{nullptr};
cudaEvent_t copy_event_; std::vector<int> offset_info_;
cudaStream_t copy_stream_;
}; };
class SlicePluginCreator : public TensorRTPluginCreator { class SlicePluginCreator : public TensorRTPluginCreator {
...@@ -144,8 +143,7 @@ class SlicePluginDynamic : public DynamicPluginTensorRT { ...@@ -144,8 +143,7 @@ class SlicePluginDynamic : public DynamicPluginTensorRT {
std::vector<int> axes_; std::vector<int> axes_;
int decrease_axis_; int decrease_axis_;
int* offset_temp_data_{nullptr}; int* offset_temp_data_{nullptr};
cudaEvent_t copy_event_; std::vector<int> offset_info_;
cudaStream_t copy_stream_;
}; };
class SlicePluginDynamicCreator : public TensorRTPluginCreator { class SlicePluginDynamicCreator : public TensorRTPluginCreator {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册