Unverified commit b779d2b8 authored by W Wilber, committed by GitHub

fix slice plugin (#43110)

Parent 12d8a567
@@ -56,8 +56,6 @@ SlicePlugin::SlicePlugin(std::vector<int> starts, std::vector<int> ends,
                          std::vector<int> axes, bool with_fp16)
     : starts_(starts), ends_(ends), axes_(axes) {
   with_fp16_ = with_fp16;
-  cudaEventCreate(&copy_event_);
-  cudaStreamCreate(&copy_stream_);
 }
 
 SlicePlugin::SlicePlugin(void const *serial_data, size_t serial_length) {
@@ -66,15 +64,10 @@ SlicePlugin::SlicePlugin(void const *serial_data, size_t serial_length) {
   DeserializeValue(&serial_data, &serial_length, &ends_);
   DeserializeValue(&serial_data, &serial_length, &axes_);
   DeserializeValue(&serial_data, &serial_length, &with_fp16_);
-  cudaEventCreate(&copy_event_);
-  cudaStreamCreate(&copy_stream_);
+  DeserializeValue(&serial_data, &serial_length, &offset_info_);
 }
 
-SlicePlugin::~SlicePlugin() {
-  cudaStreamDestroy(copy_stream_);
-  cudaEventDestroy(copy_event_);
-  cudaFree(offset_temp_data_);
-}
+SlicePlugin::~SlicePlugin() { cudaFree(offset_temp_data_); }
 
 SlicePlugin *SlicePlugin::clone() const TRT_NOEXCEPT {
   return new SlicePlugin(starts_, ends_, axes_, with_fp16_);
@@ -159,11 +152,7 @@ int SlicePlugin::enqueue(int batch_size, const void *const *inputs,
   }
 
   cudaMemcpyAsync(offset_temp_data_, offset_info.data(),
-                  sizeof(int) * 3 * num_dims, cudaMemcpyHostToDevice,
-                  copy_stream_);
-  cudaEventRecord(copy_event_, copy_stream_);
-  cudaStreamWaitEvent(stream, copy_event_, 0);
+                  sizeof(int) * 3 * num_dims, cudaMemcpyHostToDevice, stream);
 
   int threads = 256;
   int blocks = (out_num + threads - 1) / threads;
@@ -190,7 +179,7 @@ int SlicePlugin::enqueue(int batch_size, const void *const *inputs,
 size_t SlicePlugin::getSerializationSize() const TRT_NOEXCEPT {
   return getBaseSerializationSize() + SerializedSize(starts_) +
          SerializedSize(ends_) + SerializedSize(axes_) +
-         SerializedSize(with_fp16_);
+         SerializedSize(with_fp16_) + SerializedSize(offset_info_);
 }
 
 void SlicePlugin::serialize(void *buffer) const TRT_NOEXCEPT {
@@ -199,6 +188,7 @@ void SlicePlugin::serialize(void *buffer) const TRT_NOEXCEPT {
   SerializeValue(&buffer, ends_);
   SerializeValue(&buffer, axes_);
   SerializeValue(&buffer, with_fp16_);
+  SerializeValue(&buffer, offset_info_);
 }
 
 // Dynamic Plugin below.
@@ -209,8 +199,6 @@ SlicePluginDynamic::SlicePluginDynamic(std::vector<int> starts,
                                        bool with_fp16)
     : starts_(starts), ends_(ends), axes_(axes), decrease_axis_(decrease_axis) {
   with_fp16_ = with_fp16;
-  cudaEventCreate(&copy_event_);
-  cudaStreamCreate(&copy_stream_);
 }
 
 SlicePluginDynamic::SlicePluginDynamic(void const *serialData,
@@ -220,13 +208,10 @@ SlicePluginDynamic::SlicePluginDynamic(void const *serialData,
   DeserializeValue(&serialData, &serialLength, &axes_);
   DeserializeValue(&serialData, &serialLength, &decrease_axis_);
   DeserializeValue(&serialData, &serialLength, &with_fp16_);
-  cudaEventCreate(&copy_event_);
-  cudaStreamCreate(&copy_stream_);
+  DeserializeValue(&serialData, &serialLength, &offset_info_);
 }
 
 void SlicePluginDynamic::destroy() TRT_NOEXCEPT {
-  cudaStreamDestroy(copy_stream_);
-  cudaEventDestroy(copy_event_);
   cudaFree(offset_temp_data_);
   delete this;
 }
@@ -236,7 +221,7 @@ int SlicePluginDynamic::initialize() TRT_NOEXCEPT { return 0; }
 size_t SlicePluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
   size_t size = SerializedSize(starts_) + SerializedSize(ends_) +
                 SerializedSize(axes_) + SerializedSize(decrease_axis_) +
-                SerializedSize(with_fp16_);
+                SerializedSize(with_fp16_) + SerializedSize(offset_info_);
 
   return size;
 }
@@ -247,6 +232,7 @@ void SlicePluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {
   SerializeValue(&buffer, axes_);
   SerializeValue(&buffer, decrease_axis_);
   SerializeValue(&buffer, with_fp16_);
+  SerializeValue(&buffer, offset_info_);
 }
 
 nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions(
@@ -361,23 +347,19 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
     offsets[axes_[i]] = starts_[i];
   }
 
-  std::vector<int> offset_info;
+  offset_info_.resize(num_dims * 3);
   for (size_t i = 0; i < num_dims; ++i) {
-    offset_info.push_back(offsets[i]);
-    offset_info.push_back(extends[i]);
-    offset_info.push_back(seg_offsets[i]);
+    offset_info_[i * 3 + 0] = offsets[i];
+    offset_info_[i * 3 + 1] = extends[i];
+    offset_info_[i * 3 + 2] = seg_offsets[i];
   }
 
   if (offset_temp_data_ == nullptr) {
     cudaMalloc(&offset_temp_data_, 3 * num_dims * sizeof(int));
   }
 
-  cudaMemcpyAsync(offset_temp_data_, offset_info.data(),
-                  sizeof(int) * 3 * num_dims, cudaMemcpyHostToDevice,
-                  copy_stream_);
-  cudaEventRecord(copy_event_, copy_stream_);
-  cudaStreamWaitEvent(stream, copy_event_, 0);
+  cudaMemcpyAsync(offset_temp_data_, offset_info_.data(),
+                  sizeof(int) * 3 * num_dims, cudaMemcpyHostToDevice, stream);
 
   int threads = 256;
   int blocks = (out_num + threads - 1) / threads;
@@ -64,8 +64,7 @@ class SlicePlugin : public PluginTensorRT {
   std::vector<int> ends_;
   std::vector<int> axes_;
   int* offset_temp_data_{nullptr};
-  cudaEvent_t copy_event_;
-  cudaStream_t copy_stream_;
+  std::vector<int> offset_info_;
 };
 
 class SlicePluginCreator : public TensorRTPluginCreator {
@@ -144,8 +143,7 @@ class SlicePluginDynamic : public DynamicPluginTensorRT {
   std::vector<int> axes_;
   int decrease_axis_;
   int* offset_temp_data_{nullptr};
-  cudaEvent_t copy_event_;
-  cudaStream_t copy_stream_;
+  std::vector<int> offset_info_;
 };
 
 class SlicePluginDynamicCreator : public TensorRTPluginCreator {
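For context, a minimal standalone sketch follows. It is not part of this patch and uses made-up identifiers (ReadMeta, d_meta, out_num, ...); it only illustrates the CUDA stream-ordering property the fix relies on: work submitted to a single stream runs in submission order, so issuing the host-to-device copy of the offset metadata with cudaMemcpyAsync on the enqueue stream already completes it before a kernel launched afterwards on that same stream, which makes the dedicated copy_stream_/copy_event_ pair unnecessary. Promoting the local offset_info vector to the serialized offset_info_ member presumably also keeps the host buffer alive while the asynchronous copy is in flight.

```cpp
// Standalone sketch (not Paddle code): a cudaMemcpyAsync and a kernel issued
// on the same stream execute in issue order, so the kernel sees the fully
// copied metadata without cudaEventRecord/cudaStreamWaitEvent.
#include <cstdio>
#include <vector>

#include <cuda_runtime.h>

__global__ void ReadMeta(const int *meta, int meta_len, int *out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = meta[i % meta_len];  // toy consumer of the metadata
}

int main() {
  const int num_dims = 2, out_num = 8;
  // Host metadata laid out as [offset, extent, seg_offset] per dimension,
  // mirroring the 3 * num_dims layout used by the plugin.
  std::vector<int> offset_info(3 * num_dims);
  for (int i = 0; i < num_dims; ++i) {
    offset_info[i * 3 + 0] = i;      // offset
    offset_info[i * 3 + 1] = i + 4;  // extent
    offset_info[i * 3 + 2] = 1;      // segment offset
  }

  int *d_meta = nullptr, *d_out = nullptr;
  cudaMalloc(&d_meta, offset_info.size() * sizeof(int));
  cudaMalloc(&d_out, out_num * sizeof(int));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Copy and kernel are enqueued on the same stream; the kernel will not
  // start until the copy has completed.
  cudaMemcpyAsync(d_meta, offset_info.data(), offset_info.size() * sizeof(int),
                  cudaMemcpyHostToDevice, stream);
  int threads = 256;
  int blocks = (out_num + threads - 1) / threads;
  ReadMeta<<<blocks, threads, 0, stream>>>(d_meta, 3 * num_dims, d_out, out_num);

  std::vector<int> out(out_num);
  cudaMemcpyAsync(out.data(), d_out, out_num * sizeof(int),
                  cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);
  for (int v : out) printf("%d ", v);
  printf("\n");

  cudaStreamDestroy(stream);
  cudaFree(d_meta);
  cudaFree(d_out);
  return 0;
}
```

Compile as a .cu file with nvcc; error checking is omitted for brevity.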