diff --git a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu index cbd6e3a2e4ffe536cc818e0c1e82b010fe619411..2b6541c5515cec6a1cea4355af28b8cde3ca5d95 100644 --- a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu @@ -65,6 +65,7 @@ SlicePlugin::SlicePlugin(void const *serial_data, size_t serial_length) { DeserializeValue(&serial_data, &serial_length, &starts_); DeserializeValue(&serial_data, &serial_length, &ends_); DeserializeValue(&serial_data, &serial_length, &axes_); + DeserializeValue(&serial_data, &serial_length, &with_fp16_); cudaEventCreate(©_event_); cudaStreamCreate(©_stream_); } @@ -187,17 +188,17 @@ int SlicePlugin::enqueue(int batch_size, const void *const *inputs, } size_t SlicePlugin::getSerializationSize() const TRT_NOEXCEPT { - return getBaseSerializationSize() + SerializedSize(getPluginType()) + - SerializedSize(starts_) + SerializedSize(ends_) + - SerializedSize(axes_); + return getBaseSerializationSize() + SerializedSize(starts_) + + SerializedSize(ends_) + SerializedSize(axes_) + + SerializedSize(with_fp16_); } void SlicePlugin::serialize(void *buffer) const TRT_NOEXCEPT { - SerializeValue(&buffer, getPluginType()); serializeBase(buffer); SerializeValue(&buffer, starts_); SerializeValue(&buffer, ends_); SerializeValue(&buffer, axes_); + SerializeValue(&buffer, with_fp16_); } // Dynamic Plugin below. diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_slice_plugin.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_slice_plugin.py index 6ea2335c7a1b1c456a2adf2d60ddcffeb59d9367..98232838ee08b454540d4efad71ad6611a00764c 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_slice_plugin.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_slice_plugin.py @@ -86,5 +86,19 @@ class SlicePluginTRTTestFp16(SlicePluginTRTTest): self.enable_trt = True +class StaticSlicePluginTRTTestFp16(SlicePluginTRTTest): + def setUpTensorRTParams(self): + self.trt_parameters = SlicePluginTRTTest.TensorRTParam( + 1 << 30, 32, 1, AnalysisConfig.Precision.Half, True, False) + self.enable_trt = True + + +class StaticSlicePluginTRTTestFp32(SlicePluginTRTTest): + def setUpTensorRTParams(self): + self.trt_parameters = SlicePluginTRTTest.TensorRTParam( + 1 << 30, 32, 1, AnalysisConfig.Precision.Float32, True, False) + self.enable_trt = True + + if __name__ == "__main__": unittest.main()