未验证 提交 4b4d92ea 编写于 作者: X xiaoxiaohehe001 提交者: GitHub

[Paddle Inference] Add Hasattri check of op teller. (#50110)

* add_hasattri_check

* add_hasattri_check
上级 d994d212
......@@ -307,6 +307,9 @@ struct SimpleOpTypeSetTeller : public Teller {
VLOG(3) << "Deformable conv trt plugin does not support dynamic shape";
return false;
}
if (!desc.HasAttr("groups") || !desc.HasAttr("strides") ||
!desc.HasAttr("paddings"))
return false;
auto* block = desc.Block();
auto input_name = desc.Input("Input")[0];
auto* input_desc = block->FindVar(input_name);
......@@ -450,8 +453,10 @@ struct SimpleOpTypeSetTeller : public Teller {
const auto x_shape = x_var_desc->GetShape();
}
if (op_type == "group_norm") {
bool has_attrs = (desc.HasAttr("epsilon") && desc.HasAttr("groups"));
if (has_attrs == false) return false;
if (!desc.HasAttr("epsilon") || !desc.HasAttr("groups") ||
!desc.HasAttr("data_layout"))
return false;
auto registry = GetPluginRegistry();
if (registry == nullptr) return false;
std::string layout_str =
......@@ -534,6 +539,9 @@ struct SimpleOpTypeSetTeller : public Teller {
}
if (op_type == "flatten_contiguous_range") {
if (!with_dynamic_shape) {
if (!desc.HasAttr("start_axis") || !desc.HasAttr("stop_axis")) {
return false;
}
int start_axis = PADDLE_GET_CONST(int, desc.GetAttr("start_axis"));
int stop_axis = PADDLE_GET_CONST(int, desc.GetAttr("stop_axis"));
auto x_var_name = desc.Input("X")[0];
......@@ -1250,7 +1258,9 @@ struct SimpleOpTypeSetTeller : public Teller {
VLOG(3) << "the fill_any_like does not support static shape yet";
return false;
}
int dtype = PADDLE_GET_CONST(int, desc.GetAttr("dtype"));
int dtype = desc.HasAttr("dtype")
? PADDLE_GET_CONST(int, desc.GetAttr("dtype"))
: -1;
auto* block = desc.Block();
auto* x_var_desc = block->FindVar(desc.Input("X")[0]);
auto input_type = x_var_desc->GetDataType();
......@@ -1601,7 +1611,9 @@ struct SimpleOpTypeSetTeller : public Teller {
fill_constant_inputs.end()) {
if (desc.Input("ShapeTensorList").size()) return false;
}
int dtype = PADDLE_GET_CONST(int, desc.GetAttr("dtype"));
int dtype = desc.HasAttr("dtype")
? PADDLE_GET_CONST(int, desc.GetAttr("dtype"))
: 5;
// only support int32, int64, float32
if (!(dtype == 2 || dtype == 3 || dtype == 5)) {
return false;
......@@ -1662,6 +1674,7 @@ struct SimpleOpTypeSetTeller : public Teller {
}
if (op_type == "pad") {
if (!desc.HasAttr("pad_value") || !desc.HasAttr("paddings")) return false;
const float pad_value =
PADDLE_GET_CONST(float, desc.GetAttr("pad_value"));
if (pad_value != 0.0f) {
......@@ -2257,6 +2270,7 @@ struct SimpleOpTypeSetTeller : public Teller {
return false;
}
#endif
std::vector<int> paddings =
PADDLE_GET_CONST(std::vector<int>, desc.GetAttr("paddings"));
......@@ -2393,6 +2407,9 @@ struct SimpleOpTypeSetTeller : public Teller {
"static shape yet";
return false;
}
if (!desc.HasAttr("axis")) {
return false;
}
int axis = PADDLE_GET_CONST(int, desc.GetAttr("axis"));
if (axis == 0) {
return false;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册