From 50ac7dbfd05e9a5fd8bf8a87111faa1f33590f67 Mon Sep 17 00:00:00 2001 From: Shang Zhizhou Date: Fri, 12 Mar 2021 18:05:38 +0800 Subject: [PATCH] Trt elementwise plugin serialize (#31587) * add serialize unittest * fix element_op trt plugin serialize bug --- .../tensorrt/plugin/elementwise_op_plugin.cu | 9 +++- .../tensorrt/plugin/elementwise_op_plugin.h | 47 ++++++++++++++++- .../ir/inference/test_trt_subgraph_pass.py | 52 +++++++++++++++++++ 3 files changed, 105 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu index 457d9dd8737..cc17f8aa248 100644 --- a/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu @@ -152,9 +152,14 @@ int ElementWisePlugin::enqueue(int batch_size, const void *const *inputs, int ElementwisePluginDynamic::initialize() { return 0; } -size_t ElementwisePluginDynamic::getSerializationSize() const { return 0; } +size_t ElementwisePluginDynamic::getSerializationSize() const { + return SerializedSize(type_.c_str()) + SerializedSize(axis_); +} -void ElementwisePluginDynamic::serialize(void *buffer) const {} +void ElementwisePluginDynamic::serialize(void *buffer) const { + SerializeValue(&buffer, type_.c_str()); + SerializeValue(&buffer, axis_); +} nvinfer1::DimsExprs ElementwisePluginDynamic::getOutputDimensions( int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs, diff --git a/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h index e37511868d8..49212aae9aa 100644 --- a/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h @@ -92,7 +92,12 @@ class ElementwisePluginDynamic : public DynamicPluginTensorRT { public: explicit ElementwisePluginDynamic(const std::string& type, int axis) : type_(type), axis_(axis) {} - ElementwisePluginDynamic(void const* serialData, size_t serialLength) {} + ElementwisePluginDynamic(void const* serialData, size_t serialLength) { + const char* elementwise_type; + DeserializeValue(&serialData, &serialLength, &elementwise_type); + type_ = std::string(elementwise_type); + DeserializeValue(&serialData, &serialLength, &axis_); + } nvinfer1::IPluginV2DynamicExt* clone() const override { return new ElementwisePluginDynamic(type_, axis_); } @@ -138,6 +143,46 @@ class ElementwisePluginDynamic : public DynamicPluginTensorRT { std::string type_; int axis_; }; + +class ElementwisePluginV2Creator : public nvinfer1::IPluginCreator { + public: + ElementwisePluginV2Creator() {} + const char* getPluginName() const override { return "elementwise_plugin"; } + + const char* getPluginVersion() const override { return "1"; } + + const nvinfer1::PluginFieldCollection* getFieldNames() override { + return &field_collection_; + } + + nvinfer1::IPluginV2* createPlugin( + const char* name, const nvinfer1::PluginFieldCollection* fc) override { + return nullptr; + } + + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serial_data, + size_t serial_length) override { + auto plugin = new ElementwisePluginDynamic(serial_data, serial_length); + return plugin; + } + + void setPluginNamespace(const char* lib_namespace) override { + plugin_namespace_ = lib_namespace; + } + + const char* getPluginNamespace() const override { + return plugin_namespace_.c_str(); + } + + private: + std::string plugin_namespace_; + std::string plugin_name_; + nvinfer1::PluginFieldCollection field_collection_{0, nullptr}; + std::vector plugin_attributes_; +}; + +REGISTER_TRT_PLUGIN_V2(ElementwisePluginV2Creator); #endif } // namespace plugin diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py index 2c77ce17231..bdcdeee8dcb 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py @@ -414,6 +414,58 @@ class TensorRTSubgraphPassElementwiseMulTest( return fluid.layers.elementwise_mul(x=data1, y=data2) +class TensorRTSubgraphPassElementwiseSerializeTest( + TensorRTSubgraphPassElementwiseTest): + def setUp(self): + super(TensorRTSubgraphPassElementwiseSerializeTest, self).setUp() + self.trt_parameters = TensorRTSubgraphPassElementwiseTest.TensorRTParam( + 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, True, False) + + def test_check_output(self): + if os.path.exists(self.path + "_opt_cache"): + shutil.rmtree(self.path + "_opt_cache") + super(TensorRTSubgraphPassElementwiseSerializeTest, + self).test_check_output() + + +class TensorRTSubgraphPassElementwiseBroadcastDynamicTest(InferencePassTest): + def setUp(self): + with fluid.program_guard(self.main_program, self.startup_program): + data1 = fluid.data( + name="data1", shape=[-1, 3, 64, 64], dtype="float32") + data2 = fluid.data(name="data2", shape=[64, 64], dtype="float32") + eltwise_out = self.append_eltwise(data1, data2) + out = fluid.layers.batch_norm(eltwise_out, is_test=True) + self.feeds = { + "data1": np.random.random([1, 3, 64, 64]).astype("float32"), + "data2": np.random.random([64, 64]).astype("float32"), + } + self.enable_trt = True + self.trt_parameters = TensorRTSubgraphPassElementwiseBroadcastDynamicTest.TensorRTParam( + 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, True, False) + self.dynamic_shape_params = TensorRTSubgraphPassElementwiseBroadcastDynamicTest.DynamicShapeParam( + { + 'data1': [1, 3, 8, 64], + 'data2': [8, 64] + }, {'data1': [1, 3, 512, 64], + 'data2': + [512, 64]}, {'data1': [1, 3, 256, 64], + 'data2': [256, 64]}, False) + self.fetch_list = [out] + + def append_eltwise(self, data1, data2): + return fluid.layers.elementwise_add(x=data1, y=data2) + + def test_check_output(self): + if os.path.exists(self.path + "_opt_cache"): + shutil.rmtree(self.path + "_opt_cache") + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu) + self.assertTrue( + PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) + + class TensorRTSubgraphPassShuffleChannelTest(InferencePassTest): def setUp(self): with fluid.program_guard(self.main_program, self.startup_program): -- GitLab