diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h index ee8cc0c8681c0cfdffa79ba4ac083babbfcaf96f..87ad887cef38382baef56710d024dc2b4a29891e 100644 --- a/paddle/fluid/inference/tensorrt/convert/op_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h @@ -56,8 +56,9 @@ class OpConverter { OpConverter* it{nullptr}; - auto op_converter_type_map = OpTeller::Global().GetOpConverterTypeMap(); - switch (op_converter_type_map.at(op_desc.Type())) { + auto converter_type = static_cast( + PADDLE_GET_CONST(int, op_desc.GetAttr("converter_type"))); + switch (converter_type) { case OpConverterType::Default: if (op_desc.Type().find("elementwise") != std::string::npos) { static std::unordered_set add_tensor_op_set{ diff --git a/paddle/fluid/inference/tensorrt/convert/test_custom_plugin_creater.cc b/paddle/fluid/inference/tensorrt/convert/test_custom_plugin_creater.cc index 47c3793ab9f0b4d2795e7175ffee1a8221b07eeb..eee4d0c12edbe87603798ebf370d32c5ec8659c7 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_custom_plugin_creater.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_custom_plugin_creater.cc @@ -109,7 +109,7 @@ TEST(CustomPluginCreater, StaticShapePlugin) { framework::OpDesc custom_op(*op_desc, nullptr); CHECK_EQ((*custom_plugin_tell)(custom_op, false, false), true); - OpTeller::Global().SetOpConverterType("custom_op", + OpTeller::Global().SetOpConverterType(&custom_op, OpConverterType::CustomPluginCreater); OpConverter converter; @@ -196,7 +196,7 @@ TEST(CustomPluginCreater, DynamicShapePlugin) { framework::OpDesc custom_op(*op_desc, nullptr); CHECK_EQ((*custom_plugin_tell)(custom_op, false, true), true); - OpTeller::Global().SetOpConverterType("custom_op", + OpTeller::Global().SetOpConverterType(&custom_op, OpConverterType::CustomPluginCreater); OpConverter converter; diff --git a/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc b/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc index ee45b602b3ab9a4c9b39ca10b38a66ff80823b02..3a15af255e5bce4118a23120b1c5c9293817ed65 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc @@ -57,7 +57,7 @@ TEST(OpConverter, ConvertBlock) { x_tensor->Resize(phi::make_ddim(dim_vec)); x_tensor->mutable_data(platform::CUDAPlace(0)); - OpTeller::Global().SetOpConverterType("conv2d", OpConverterType::Default); + OpTeller::Global().SetOpConverterType(conv2d_op, OpConverterType::Default); OpConverter converter; converter.ConvertBlock( *block->Proto(), {"conv2d-Y"}, scope, engine_.get() /*TensorRTEngine*/); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index b17aca9e8cb4da8df8e249a2a51d35b03188996e..78e300a8d730d58357b13dc8b0f54fd772086452 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -3080,17 +3080,17 @@ bool OpTeller::Tell(const framework::ir::Node* node, return false; auto& default_teller = GetDefaultTeller(); if ((*default_teller)(desc, use_no_calib_int8, with_dynamic_shape)) { - SetOpConverterType(op_type, OpConverterType::Default); + SetOpConverterType(node->Op(), OpConverterType::Default); return true; } auto& generic_plugin_teller = GetGenericPluginTeller(); if ((*generic_plugin_teller)(desc, use_no_calib_int8, with_dynamic_shape)) { - SetOpConverterType(op_type, OpConverterType::GenericPluginCreater); + SetOpConverterType(node->Op(), OpConverterType::GenericPluginCreater); return true; } auto& custom_plugin_teller = GetCustomPluginTeller(); if ((*custom_plugin_teller)(desc, use_no_calib_int8, with_dynamic_shape)) { - SetOpConverterType(op_type, OpConverterType::CustomPluginCreater); + SetOpConverterType(node->Op(), OpConverterType::CustomPluginCreater); return true; } return false; diff --git a/paddle/fluid/inference/tensorrt/op_teller.h b/paddle/fluid/inference/tensorrt/op_teller.h index 2fa3dc361217edb9f358d2e0acd76604fc7a6e91..cb879cb5f9f61cc02b37b417c5165f41c3c27f2e 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.h +++ b/paddle/fluid/inference/tensorrt/op_teller.h @@ -82,12 +82,8 @@ class OpTeller { std::unique_ptr& GetCustomPluginTeller() { return tellers_.at(2); } - void SetOpConverterType(std::string name, OpConverterType type) { - op_converter_type_map_[name] = type; - } - - const std::map& GetOpConverterTypeMap() const { - return op_converter_type_map_; + void SetOpConverterType(framework::OpDesc* op_desc, OpConverterType type) { + op_desc->SetAttr("converter_type", static_cast(type)); } private: @@ -95,7 +91,6 @@ class OpTeller { private: std::vector> tellers_; - std::map op_converter_type_map_; }; } // namespace tensorrt diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc b/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc index 5a9fa7241e853e98ca1d64a261ca6c7f9e2a8096..3d96361d89f043da5a6acff7ad1ed96e20bac583 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc @@ -94,6 +94,11 @@ void DynamicShapeTest(bool allow_build_at_runtime) { "Out", std::vector({"z0"})); // 2 x 4 x 4 x 4 elementwise_add1->SetAttr("axis", static_cast(0)); + inference::tensorrt::OpTeller::Global().SetOpConverterType( + elementwise_add0, inference::tensorrt::OpConverterType::Default); + inference::tensorrt::OpTeller::Global().SetOpConverterType( + elementwise_add1, inference::tensorrt::OpConverterType::Default); + // Set inputs' variable shape in BlockDesc // the batch size is 2, so the dims of 'x' is {2, 4} AddTensorToBlockDesc(block_, "x", std::vector({2, 4, 4, 4})); @@ -170,8 +175,6 @@ void DynamicShapeTest(bool allow_build_at_runtime) { // Execute them. LOG(INFO) << "engine_op run"; - inference::tensorrt::OpTeller::Global().SetOpConverterType( - "elementwise_add", inference::tensorrt::OpConverterType::Default); engine_op->Run(scope, place); }