未验证 提交 4b4d92ea 编写于 作者: X xiaoxiaohehe001 提交者: GitHub

[Paddle Inference] Add Hasattri check of op teller. (#50110)

* add_hasattri_check

* add_hasattri_check
上级 d994d212
...@@ -307,6 +307,9 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -307,6 +307,9 @@ struct SimpleOpTypeSetTeller : public Teller {
VLOG(3) << "Deformable conv trt plugin does not support dynamic shape"; VLOG(3) << "Deformable conv trt plugin does not support dynamic shape";
return false; return false;
} }
if (!desc.HasAttr("groups") || !desc.HasAttr("strides") ||
!desc.HasAttr("paddings"))
return false;
auto* block = desc.Block(); auto* block = desc.Block();
auto input_name = desc.Input("Input")[0]; auto input_name = desc.Input("Input")[0];
auto* input_desc = block->FindVar(input_name); auto* input_desc = block->FindVar(input_name);
...@@ -450,8 +453,10 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -450,8 +453,10 @@ struct SimpleOpTypeSetTeller : public Teller {
const auto x_shape = x_var_desc->GetShape(); const auto x_shape = x_var_desc->GetShape();
} }
if (op_type == "group_norm") { if (op_type == "group_norm") {
bool has_attrs = (desc.HasAttr("epsilon") && desc.HasAttr("groups")); if (!desc.HasAttr("epsilon") || !desc.HasAttr("groups") ||
if (has_attrs == false) return false; !desc.HasAttr("data_layout"))
return false;
auto registry = GetPluginRegistry(); auto registry = GetPluginRegistry();
if (registry == nullptr) return false; if (registry == nullptr) return false;
std::string layout_str = std::string layout_str =
...@@ -534,6 +539,9 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -534,6 +539,9 @@ struct SimpleOpTypeSetTeller : public Teller {
} }
if (op_type == "flatten_contiguous_range") { if (op_type == "flatten_contiguous_range") {
if (!with_dynamic_shape) { if (!with_dynamic_shape) {
if (!desc.HasAttr("start_axis") || !desc.HasAttr("stop_axis")) {
return false;
}
int start_axis = PADDLE_GET_CONST(int, desc.GetAttr("start_axis")); int start_axis = PADDLE_GET_CONST(int, desc.GetAttr("start_axis"));
int stop_axis = PADDLE_GET_CONST(int, desc.GetAttr("stop_axis")); int stop_axis = PADDLE_GET_CONST(int, desc.GetAttr("stop_axis"));
auto x_var_name = desc.Input("X")[0]; auto x_var_name = desc.Input("X")[0];
...@@ -1250,7 +1258,9 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -1250,7 +1258,9 @@ struct SimpleOpTypeSetTeller : public Teller {
VLOG(3) << "the fill_any_like does not support static shape yet"; VLOG(3) << "the fill_any_like does not support static shape yet";
return false; return false;
} }
int dtype = PADDLE_GET_CONST(int, desc.GetAttr("dtype")); int dtype = desc.HasAttr("dtype")
? PADDLE_GET_CONST(int, desc.GetAttr("dtype"))
: -1;
auto* block = desc.Block(); auto* block = desc.Block();
auto* x_var_desc = block->FindVar(desc.Input("X")[0]); auto* x_var_desc = block->FindVar(desc.Input("X")[0]);
auto input_type = x_var_desc->GetDataType(); auto input_type = x_var_desc->GetDataType();
...@@ -1601,7 +1611,9 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -1601,7 +1611,9 @@ struct SimpleOpTypeSetTeller : public Teller {
fill_constant_inputs.end()) { fill_constant_inputs.end()) {
if (desc.Input("ShapeTensorList").size()) return false; if (desc.Input("ShapeTensorList").size()) return false;
} }
int dtype = PADDLE_GET_CONST(int, desc.GetAttr("dtype")); int dtype = desc.HasAttr("dtype")
? PADDLE_GET_CONST(int, desc.GetAttr("dtype"))
: 5;
// only support int32, int64, float32 // only support int32, int64, float32
if (!(dtype == 2 || dtype == 3 || dtype == 5)) { if (!(dtype == 2 || dtype == 3 || dtype == 5)) {
return false; return false;
...@@ -1662,6 +1674,7 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -1662,6 +1674,7 @@ struct SimpleOpTypeSetTeller : public Teller {
} }
if (op_type == "pad") { if (op_type == "pad") {
if (!desc.HasAttr("pad_value") || !desc.HasAttr("paddings")) return false;
const float pad_value = const float pad_value =
PADDLE_GET_CONST(float, desc.GetAttr("pad_value")); PADDLE_GET_CONST(float, desc.GetAttr("pad_value"));
if (pad_value != 0.0f) { if (pad_value != 0.0f) {
...@@ -2257,6 +2270,7 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -2257,6 +2270,7 @@ struct SimpleOpTypeSetTeller : public Teller {
return false; return false;
} }
#endif #endif
std::vector<int> paddings = std::vector<int> paddings =
PADDLE_GET_CONST(std::vector<int>, desc.GetAttr("paddings")); PADDLE_GET_CONST(std::vector<int>, desc.GetAttr("paddings"));
...@@ -2393,6 +2407,9 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -2393,6 +2407,9 @@ struct SimpleOpTypeSetTeller : public Teller {
"static shape yet"; "static shape yet";
return false; return false;
} }
if (!desc.HasAttr("axis")) {
return false;
}
int axis = PADDLE_GET_CONST(int, desc.GetAttr("axis")); int axis = PADDLE_GET_CONST(int, desc.GetAttr("axis"));
if (axis == 0) { if (axis == 0) {
return false; return false;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册