未验证 提交 c7c5635e 编写于 作者: Z zhoutianzi666 提交者: GitHub

[Paddle-TRT] The Graph uses OpConverterType for op converter (#53214)

* add ```converter_type``` for op converter
上级 46951224
...@@ -56,8 +56,9 @@ class OpConverter { ...@@ -56,8 +56,9 @@ class OpConverter {
OpConverter* it{nullptr}; OpConverter* it{nullptr};
auto op_converter_type_map = OpTeller::Global().GetOpConverterTypeMap(); auto converter_type = static_cast<OpConverterType>(
switch (op_converter_type_map.at(op_desc.Type())) { PADDLE_GET_CONST(int, op_desc.GetAttr("converter_type")));
switch (converter_type) {
case OpConverterType::Default: case OpConverterType::Default:
if (op_desc.Type().find("elementwise") != std::string::npos) { if (op_desc.Type().find("elementwise") != std::string::npos) {
static std::unordered_set<std::string> add_tensor_op_set{ static std::unordered_set<std::string> add_tensor_op_set{
......
...@@ -109,7 +109,7 @@ TEST(CustomPluginCreater, StaticShapePlugin) { ...@@ -109,7 +109,7 @@ TEST(CustomPluginCreater, StaticShapePlugin) {
framework::OpDesc custom_op(*op_desc, nullptr); framework::OpDesc custom_op(*op_desc, nullptr);
CHECK_EQ((*custom_plugin_tell)(custom_op, false, false), true); CHECK_EQ((*custom_plugin_tell)(custom_op, false, false), true);
OpTeller::Global().SetOpConverterType("custom_op", OpTeller::Global().SetOpConverterType(&custom_op,
OpConverterType::CustomPluginCreater); OpConverterType::CustomPluginCreater);
OpConverter converter; OpConverter converter;
...@@ -196,7 +196,7 @@ TEST(CustomPluginCreater, DynamicShapePlugin) { ...@@ -196,7 +196,7 @@ TEST(CustomPluginCreater, DynamicShapePlugin) {
framework::OpDesc custom_op(*op_desc, nullptr); framework::OpDesc custom_op(*op_desc, nullptr);
CHECK_EQ((*custom_plugin_tell)(custom_op, false, true), true); CHECK_EQ((*custom_plugin_tell)(custom_op, false, true), true);
OpTeller::Global().SetOpConverterType("custom_op", OpTeller::Global().SetOpConverterType(&custom_op,
OpConverterType::CustomPluginCreater); OpConverterType::CustomPluginCreater);
OpConverter converter; OpConverter converter;
......
...@@ -57,7 +57,7 @@ TEST(OpConverter, ConvertBlock) { ...@@ -57,7 +57,7 @@ TEST(OpConverter, ConvertBlock) {
x_tensor->Resize(phi::make_ddim(dim_vec)); x_tensor->Resize(phi::make_ddim(dim_vec));
x_tensor->mutable_data<float>(platform::CUDAPlace(0)); x_tensor->mutable_data<float>(platform::CUDAPlace(0));
OpTeller::Global().SetOpConverterType("conv2d", OpConverterType::Default); OpTeller::Global().SetOpConverterType(conv2d_op, OpConverterType::Default);
OpConverter converter; OpConverter converter;
converter.ConvertBlock( converter.ConvertBlock(
*block->Proto(), {"conv2d-Y"}, scope, engine_.get() /*TensorRTEngine*/); *block->Proto(), {"conv2d-Y"}, scope, engine_.get() /*TensorRTEngine*/);
......
...@@ -3080,17 +3080,17 @@ bool OpTeller::Tell(const framework::ir::Node* node, ...@@ -3080,17 +3080,17 @@ bool OpTeller::Tell(const framework::ir::Node* node,
return false; return false;
auto& default_teller = GetDefaultTeller(); auto& default_teller = GetDefaultTeller();
if ((*default_teller)(desc, use_no_calib_int8, with_dynamic_shape)) { if ((*default_teller)(desc, use_no_calib_int8, with_dynamic_shape)) {
SetOpConverterType(op_type, OpConverterType::Default); SetOpConverterType(node->Op(), OpConverterType::Default);
return true; return true;
} }
auto& generic_plugin_teller = GetGenericPluginTeller(); auto& generic_plugin_teller = GetGenericPluginTeller();
if ((*generic_plugin_teller)(desc, use_no_calib_int8, with_dynamic_shape)) { if ((*generic_plugin_teller)(desc, use_no_calib_int8, with_dynamic_shape)) {
SetOpConverterType(op_type, OpConverterType::GenericPluginCreater); SetOpConverterType(node->Op(), OpConverterType::GenericPluginCreater);
return true; return true;
} }
auto& custom_plugin_teller = GetCustomPluginTeller(); auto& custom_plugin_teller = GetCustomPluginTeller();
if ((*custom_plugin_teller)(desc, use_no_calib_int8, with_dynamic_shape)) { if ((*custom_plugin_teller)(desc, use_no_calib_int8, with_dynamic_shape)) {
SetOpConverterType(op_type, OpConverterType::CustomPluginCreater); SetOpConverterType(node->Op(), OpConverterType::CustomPluginCreater);
return true; return true;
} }
return false; return false;
......
...@@ -82,12 +82,8 @@ class OpTeller { ...@@ -82,12 +82,8 @@ class OpTeller {
std::unique_ptr<Teller>& GetCustomPluginTeller() { return tellers_.at(2); } std::unique_ptr<Teller>& GetCustomPluginTeller() { return tellers_.at(2); }
void SetOpConverterType(std::string name, OpConverterType type) { void SetOpConverterType(framework::OpDesc* op_desc, OpConverterType type) {
op_converter_type_map_[name] = type; op_desc->SetAttr("converter_type", static_cast<int>(type));
}
const std::map<std::string, OpConverterType>& GetOpConverterTypeMap() const {
return op_converter_type_map_;
} }
private: private:
...@@ -95,7 +91,6 @@ class OpTeller { ...@@ -95,7 +91,6 @@ class OpTeller {
private: private:
std::vector<std::unique_ptr<Teller>> tellers_; std::vector<std::unique_ptr<Teller>> tellers_;
std::map<std::string, OpConverterType> op_converter_type_map_;
}; };
} // namespace tensorrt } // namespace tensorrt
......
...@@ -94,6 +94,11 @@ void DynamicShapeTest(bool allow_build_at_runtime) { ...@@ -94,6 +94,11 @@ void DynamicShapeTest(bool allow_build_at_runtime) {
"Out", std::vector<std::string>({"z0"})); // 2 x 4 x 4 x 4 "Out", std::vector<std::string>({"z0"})); // 2 x 4 x 4 x 4
elementwise_add1->SetAttr("axis", static_cast<int32_t>(0)); elementwise_add1->SetAttr("axis", static_cast<int32_t>(0));
inference::tensorrt::OpTeller::Global().SetOpConverterType(
elementwise_add0, inference::tensorrt::OpConverterType::Default);
inference::tensorrt::OpTeller::Global().SetOpConverterType(
elementwise_add1, inference::tensorrt::OpConverterType::Default);
// Set inputs' variable shape in BlockDesc // Set inputs' variable shape in BlockDesc
// the batch size is 2, so the dims of 'x' is {2, 4} // the batch size is 2, so the dims of 'x' is {2, 4}
AddTensorToBlockDesc(block_, "x", std::vector<int64_t>({2, 4, 4, 4})); AddTensorToBlockDesc(block_, "x", std::vector<int64_t>({2, 4, 4, 4}));
...@@ -170,8 +175,6 @@ void DynamicShapeTest(bool allow_build_at_runtime) { ...@@ -170,8 +175,6 @@ void DynamicShapeTest(bool allow_build_at_runtime) {
// Execute them. // Execute them.
LOG(INFO) << "engine_op run"; LOG(INFO) << "engine_op run";
inference::tensorrt::OpTeller::Global().SetOpConverterType(
"elementwise_add", inference::tensorrt::OpConverterType::Default);
engine_op->Run(scope, place); engine_op->Run(scope, place);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册