From c7c5635e0752e3cdc4aa4cc50c10294cf154000e Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <39978853+zhoutianzi666@users.noreply.github.com> Date: Tue, 25 Apr 2023 13:24:10 +0800 Subject: [PATCH] [Paddle-TRT] The Graph uses OpConverterType for op converter (#53214) * add ```converter_type``` for op converter --- paddle/fluid/inference/tensorrt/convert/op_converter.h | 5 +++-- .../tensorrt/convert/test_custom_plugin_creater.cc | 4 ++-- .../inference/tensorrt/convert/test_op_converter.cc | 2 +- paddle/fluid/inference/tensorrt/op_teller.cc | 6 +++--- paddle/fluid/inference/tensorrt/op_teller.h | 9 ++------- .../fluid/operators/tensorrt/tensorrt_engine_op_test.cc | 7 +++++-- 6 files changed, 16 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h index ee8cc0c8681..87ad887cef3 100644 --- a/paddle/fluid/inference/tensorrt/convert/op_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h @@ -56,8 +56,9 @@ class OpConverter { OpConverter* it{nullptr}; - auto op_converter_type_map = OpTeller::Global().GetOpConverterTypeMap(); - switch (op_converter_type_map.at(op_desc.Type())) { + auto converter_type = static_cast( + PADDLE_GET_CONST(int, op_desc.GetAttr("converter_type"))); + switch (converter_type) { case OpConverterType::Default: if (op_desc.Type().find("elementwise") != std::string::npos) { static std::unordered_set add_tensor_op_set{ diff --git a/paddle/fluid/inference/tensorrt/convert/test_custom_plugin_creater.cc b/paddle/fluid/inference/tensorrt/convert/test_custom_plugin_creater.cc index 47c3793ab9f..eee4d0c12ed 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_custom_plugin_creater.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_custom_plugin_creater.cc @@ -109,7 +109,7 @@ TEST(CustomPluginCreater, StaticShapePlugin) { framework::OpDesc custom_op(*op_desc, nullptr); CHECK_EQ((*custom_plugin_tell)(custom_op, false, false), true); - OpTeller::Global().SetOpConverterType("custom_op", + OpTeller::Global().SetOpConverterType(&custom_op, OpConverterType::CustomPluginCreater); OpConverter converter; @@ -196,7 +196,7 @@ TEST(CustomPluginCreater, DynamicShapePlugin) { framework::OpDesc custom_op(*op_desc, nullptr); CHECK_EQ((*custom_plugin_tell)(custom_op, false, true), true); - OpTeller::Global().SetOpConverterType("custom_op", + OpTeller::Global().SetOpConverterType(&custom_op, OpConverterType::CustomPluginCreater); OpConverter converter; diff --git a/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc b/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc index ee45b602b3a..3a15af255e5 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc @@ -57,7 +57,7 @@ TEST(OpConverter, ConvertBlock) { x_tensor->Resize(phi::make_ddim(dim_vec)); x_tensor->mutable_data(platform::CUDAPlace(0)); - OpTeller::Global().SetOpConverterType("conv2d", OpConverterType::Default); + OpTeller::Global().SetOpConverterType(conv2d_op, OpConverterType::Default); OpConverter converter; converter.ConvertBlock( *block->Proto(), {"conv2d-Y"}, scope, engine_.get() /*TensorRTEngine*/); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index b17aca9e8cb..78e300a8d73 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -3080,17 +3080,17 @@ bool OpTeller::Tell(const framework::ir::Node* node, return false; auto& default_teller = GetDefaultTeller(); if ((*default_teller)(desc, use_no_calib_int8, with_dynamic_shape)) { - SetOpConverterType(op_type, OpConverterType::Default); + SetOpConverterType(node->Op(), OpConverterType::Default); return true; } auto& generic_plugin_teller = GetGenericPluginTeller(); if ((*generic_plugin_teller)(desc, use_no_calib_int8, with_dynamic_shape)) { - SetOpConverterType(op_type, OpConverterType::GenericPluginCreater); + SetOpConverterType(node->Op(), OpConverterType::GenericPluginCreater); return true; } auto& custom_plugin_teller = GetCustomPluginTeller(); if ((*custom_plugin_teller)(desc, use_no_calib_int8, with_dynamic_shape)) { - SetOpConverterType(op_type, OpConverterType::CustomPluginCreater); + SetOpConverterType(node->Op(), OpConverterType::CustomPluginCreater); return true; } return false; diff --git a/paddle/fluid/inference/tensorrt/op_teller.h b/paddle/fluid/inference/tensorrt/op_teller.h index 2fa3dc36121..cb879cb5f9f 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.h +++ b/paddle/fluid/inference/tensorrt/op_teller.h @@ -82,12 +82,8 @@ class OpTeller { std::unique_ptr& GetCustomPluginTeller() { return tellers_.at(2); } - void SetOpConverterType(std::string name, OpConverterType type) { - op_converter_type_map_[name] = type; - } - - const std::map& GetOpConverterTypeMap() const { - return op_converter_type_map_; + void SetOpConverterType(framework::OpDesc* op_desc, OpConverterType type) { + op_desc->SetAttr("converter_type", static_cast(type)); } private: @@ -95,7 +91,6 @@ class OpTeller { private: std::vector> tellers_; - std::map op_converter_type_map_; }; } // namespace tensorrt diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc b/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc index 5a9fa7241e8..3d96361d89f 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc @@ -94,6 +94,11 @@ void DynamicShapeTest(bool allow_build_at_runtime) { "Out", std::vector({"z0"})); // 2 x 4 x 4 x 4 elementwise_add1->SetAttr("axis", static_cast(0)); + inference::tensorrt::OpTeller::Global().SetOpConverterType( + elementwise_add0, inference::tensorrt::OpConverterType::Default); + inference::tensorrt::OpTeller::Global().SetOpConverterType( + elementwise_add1, inference::tensorrt::OpConverterType::Default); + // Set inputs' variable shape in BlockDesc // the batch size is 2, so the dims of 'x' is {2, 4} AddTensorToBlockDesc(block_, "x", std::vector({2, 4, 4, 4})); @@ -170,8 +175,6 @@ void DynamicShapeTest(bool allow_build_at_runtime) { // Execute them. LOG(INFO) << "engine_op run"; - inference::tensorrt::OpTeller::Global().SetOpConverterType( - "elementwise_add", inference::tensorrt::OpConverterType::Default); engine_op->Run(scope, place); } -- GitLab