未验证 提交 749667e5 编写于 作者: A Aurelius84 提交者: GitHub

[OpAttr]Refine Teller logic if encounter OpDesc with Variable type Attribute (#45874)

上级 cdda9799
......@@ -303,9 +303,6 @@ bool OpTeller::Tell(const framework::ir::Node* node,
desc.HasAttr("skip_quant"))
return false;
// do not support Attribute with Variable(s) Type
if (HasUnsupportAttrVar(desc)) return false;
for (auto& teller : tellers_) {
std::unordered_set<std::string> act_op_list = {
"relu", "relu6", "sigmoid",
......@@ -364,7 +361,30 @@ bool OpTeller::Tell(const framework::ir::Node* node,
}
}
if (op_type == "dropout") {
/*
* Some OpDescs Attribute support both constant value and dynamic
* runtime value (which is a Variable(s) type). But TensorRT maybe
* only support constant value Attribute, so we shall distinguish
* this case in time and return False in OpTeller.Tell().
* If Attribute is Variable(s), HasAttr() will return False
*/
if (!desc.HasAttr("dropout_prob", /*with_attr_var=*/false)) {
VLOG(3)
<< "Skip to convert into TRT while found Attribute('dropout_prob') "
"is Variable type in dropout.";
return false;
}
}
if (op_type == "pool2d") {
// If Attribute is Variable(s), HasAttr() will return False
if (!desc.HasAttr("ksize", /*with_attr_var=*/false)) {
VLOG(3) << "Skip to convert into TRT while found Attribute('ksize') is "
"Variable type in pool2d.";
return false;
}
std::vector<int> paddings =
PADDLE_GET_CONST(std::vector<int>, desc.GetAttr("paddings"));
if (paddings.size() > 2) {
......@@ -797,6 +817,12 @@ bool OpTeller::Tell(const framework::ir::Node* node,
}
if (op_type == "arg_max") {
if (!desc.HasAttr("axis", /*with_attr_var=*/false)) {
VLOG(3) << "Skip to convert into TRT while found Attribute('axis') is "
"Variable type in arg_max.";
return false;
}
int axis = desc.HasAttr("axis")
? PADDLE_GET_CONST(int64_t, desc.GetAttr("axis"))
: -1;
......@@ -1061,6 +1087,13 @@ bool OpTeller::Tell(const framework::ir::Node* node,
}
if (op_type == "squeeze2") {
// If Attribute is Variable(s), HasAttr() will return False
if (!desc.HasAttr("axes", /*with_attr_var=*/false)) {
VLOG(3) << "Skip to convert into TRT while found Attribute('axes') is "
"Variable type in squeeze2.";
return false;
}
std::vector<int> axes;
if (desc.HasAttr("axes")) {
axes = PADDLE_GET_CONST(std::vector<int>, desc.GetAttr("axes"));
......@@ -2002,6 +2035,13 @@ bool OpTeller::Tell(const framework::ir::Node* node,
}
if (op_type == "reduce_sum" || op_type == "reduce_mean") {
if (!desc.HasAttr("dim", /*with_attr_var=*/false)) {
VLOG(3) << "Skip to convert into TRT while found Attribute('dim') is "
"Variable type in "
<< desc.Type();
return false;
}
if (!(desc.HasAttr("keep_dim") && desc.HasAttr("dim") &&
desc.HasAttr("reduce_all"))) {
VLOG(3) << "the " << op_type
......@@ -2265,34 +2305,6 @@ bool OpTeller::Tell(const framework::ir::Node* node,
return false;
}
bool OpTeller::HasUnsupportAttrVar(const framework::OpDesc& desc) const {
const std::string op_type = desc.Type();
auto has_attr_var = [&](const std::string& attr_name) -> bool {
// If Attribute is Variable(s), HasAttr() will return False
return !desc.HasAttr(attr_name, /*with_attr_var=*/false);
};
std::unordered_map<std::string, std::vector<std::string>> attrs_info = {
{"dropout", {"dropout_prob"}},
{"pool2d", {"ksize"}},
{"arg_max", {"axis"}},
{"reduce_mean", {"dim"}},
{"reduce_sum", {"dim"}},
{"squeeze2", {"axes"}},
};
bool flag = false;
auto iter = attrs_info.find(op_type);
if (iter != attrs_info.end()) {
for (auto& attr_name : iter->second) {
if (has_attr_var(attr_name)) {
flag = true;
break;
}
}
}
return flag;
}
OpTeller::OpTeller() { tellers_.emplace_back(new SimpleOpTypeSetTeller); }
} // namespace tensorrt
} // namespace inference
......
......@@ -73,14 +73,6 @@ class OpTeller {
private:
OpTeller();
/*
* Some OpDescs Attribute support both constant value and dynamic
* runtime value (which is a Variable(s) type). But TensorRT maybe
* only support constant value Attribute, so we shall distinguish
* this case in time and return False in OpTeller.Tell().
*/
bool HasUnsupportAttrVar(const framework::OpDesc& desc) const;
private:
std::vector<std::unique_ptr<Teller>> tellers_;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册