From 4b4d92eae39b1b6a4573b6a6c4964ddcbf2cc802 Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 <49090790+xiaoxiaohehe001@users.noreply.github.com> Date: Mon, 6 Feb 2023 13:15:02 +0800 Subject: [PATCH] [Paddle Inference] Add Hasattri check of op teller. (#50110) * add_hasattri_check * add_hasattri_check --- paddle/fluid/inference/tensorrt/op_teller.cc | 25 ++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index e9c34408bb..06b2b0348d 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -307,6 +307,9 @@ struct SimpleOpTypeSetTeller : public Teller { VLOG(3) << "Deformable conv trt plugin does not support dynamic shape"; return false; } + if (!desc.HasAttr("groups") || !desc.HasAttr("strides") || + !desc.HasAttr("paddings")) + return false; auto* block = desc.Block(); auto input_name = desc.Input("Input")[0]; auto* input_desc = block->FindVar(input_name); @@ -450,8 +453,10 @@ struct SimpleOpTypeSetTeller : public Teller { const auto x_shape = x_var_desc->GetShape(); } if (op_type == "group_norm") { - bool has_attrs = (desc.HasAttr("epsilon") && desc.HasAttr("groups")); - if (has_attrs == false) return false; + if (!desc.HasAttr("epsilon") || !desc.HasAttr("groups") || + !desc.HasAttr("data_layout")) + return false; + auto registry = GetPluginRegistry(); if (registry == nullptr) return false; std::string layout_str = @@ -534,6 +539,9 @@ struct SimpleOpTypeSetTeller : public Teller { } if (op_type == "flatten_contiguous_range") { if (!with_dynamic_shape) { + if (!desc.HasAttr("start_axis") || !desc.HasAttr("stop_axis")) { + return false; + } int start_axis = PADDLE_GET_CONST(int, desc.GetAttr("start_axis")); int stop_axis = PADDLE_GET_CONST(int, desc.GetAttr("stop_axis")); auto x_var_name = desc.Input("X")[0]; @@ -1250,7 +1258,9 @@ struct SimpleOpTypeSetTeller : public Teller { VLOG(3) << "the fill_any_like does not support static shape yet"; return false; } - int dtype = PADDLE_GET_CONST(int, desc.GetAttr("dtype")); + int dtype = desc.HasAttr("dtype") + ? PADDLE_GET_CONST(int, desc.GetAttr("dtype")) + : -1; auto* block = desc.Block(); auto* x_var_desc = block->FindVar(desc.Input("X")[0]); auto input_type = x_var_desc->GetDataType(); @@ -1601,7 +1611,9 @@ struct SimpleOpTypeSetTeller : public Teller { fill_constant_inputs.end()) { if (desc.Input("ShapeTensorList").size()) return false; } - int dtype = PADDLE_GET_CONST(int, desc.GetAttr("dtype")); + int dtype = desc.HasAttr("dtype") + ? PADDLE_GET_CONST(int, desc.GetAttr("dtype")) + : 5; // only support int32, int64, float32 if (!(dtype == 2 || dtype == 3 || dtype == 5)) { return false; @@ -1662,6 +1674,7 @@ struct SimpleOpTypeSetTeller : public Teller { } if (op_type == "pad") { + if (!desc.HasAttr("pad_value") || !desc.HasAttr("paddings")) return false; const float pad_value = PADDLE_GET_CONST(float, desc.GetAttr("pad_value")); if (pad_value != 0.0f) { @@ -2257,6 +2270,7 @@ struct SimpleOpTypeSetTeller : public Teller { return false; } #endif + std::vector paddings = PADDLE_GET_CONST(std::vector, desc.GetAttr("paddings")); @@ -2393,6 +2407,9 @@ struct SimpleOpTypeSetTeller : public Teller { "static shape yet"; return false; } + if (!desc.HasAttr("axis")) { + return false; + } int axis = PADDLE_GET_CONST(int, desc.GetAttr("axis")); if (axis == 0) { return false; -- GitLab