From 5e880840b8d8d7c3f0b1db74d265986b51e61d30 Mon Sep 17 00:00:00 2001 From: wenbin Date: Fri, 22 Oct 2021 16:58:51 +0800 Subject: [PATCH] correct slice serialize data (#36588) * slice * add UT --- .../inference/tensorrt/plugin/slice_op_plugin.cu | 9 +++++---- .../ir/inference/test_trt_slice_plugin.py | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu index cbd6e3a2e4f..2b6541c5515 100644 --- a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu @@ -65,6 +65,7 @@ SlicePlugin::SlicePlugin(void const *serial_data, size_t serial_length) { DeserializeValue(&serial_data, &serial_length, &starts_); DeserializeValue(&serial_data, &serial_length, &ends_); DeserializeValue(&serial_data, &serial_length, &axes_); + DeserializeValue(&serial_data, &serial_length, &with_fp16_); cudaEventCreate(©_event_); cudaStreamCreate(©_stream_); } @@ -187,17 +188,17 @@ int SlicePlugin::enqueue(int batch_size, const void *const *inputs, } size_t SlicePlugin::getSerializationSize() const TRT_NOEXCEPT { - return getBaseSerializationSize() + SerializedSize(getPluginType()) + - SerializedSize(starts_) + SerializedSize(ends_) + - SerializedSize(axes_); + return getBaseSerializationSize() + SerializedSize(starts_) + + SerializedSize(ends_) + SerializedSize(axes_) + + SerializedSize(with_fp16_); } void SlicePlugin::serialize(void *buffer) const TRT_NOEXCEPT { - SerializeValue(&buffer, getPluginType()); serializeBase(buffer); SerializeValue(&buffer, starts_); SerializeValue(&buffer, ends_); SerializeValue(&buffer, axes_); + SerializeValue(&buffer, with_fp16_); } // Dynamic Plugin below. diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_slice_plugin.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_slice_plugin.py index 6ea2335c7a1..98232838ee0 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_slice_plugin.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_slice_plugin.py @@ -86,5 +86,19 @@ class SlicePluginTRTTestFp16(SlicePluginTRTTest): self.enable_trt = True +class StaticSlicePluginTRTTestFp16(SlicePluginTRTTest): + def setUpTensorRTParams(self): + self.trt_parameters = SlicePluginTRTTest.TensorRTParam( + 1 << 30, 32, 1, AnalysisConfig.Precision.Half, True, False) + self.enable_trt = True + + +class StaticSlicePluginTRTTestFp32(SlicePluginTRTTest): + def setUpTensorRTParams(self): + self.trt_parameters = SlicePluginTRTTest.TensorRTParam( + 1 << 30, 32, 1, AnalysisConfig.Precision.Float32, True, False) + self.enable_trt = True + + if __name__ == "__main__": unittest.main() -- GitLab