Unverified commit 749667e5, authored by Aurelius84 and committed by GitHub

[OpAttr]Refine Teller logic if encounter OpDesc with Variable type Attribute (#45874)

Parent: cdda9799
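This change drops the centralized `HasUnsupportAttrVar()` helper and instead guards each affected branch of `OpTeller::Tell()` (dropout, pool2d, arg_max, squeeze2, reduce_sum/reduce_mean) in place: if the attribute the TensorRT converter needs as a constant is actually bound to a Variable, the op is not offloaded. Below is a minimal stand-alone sketch of that behavior; `FakeOpDesc` and its members are hypothetical stand-ins that only mimic the `HasAttr(name, /*with_attr_var=*/false)` semantics used in the diff, not the real `framework::OpDesc`.

```cpp
#include <iostream>
#include <map>
#include <set>
#include <string>

// Hypothetical mock of the attribute lookup used by OpTeller in this commit.
// It is NOT paddle::framework::OpDesc; it only models the distinction between
// constant attributes and attributes bound to a runtime Variable.
struct FakeOpDesc {
  std::map<std::string, float> constant_attrs;  // plain constant values
  std::set<std::string> variable_attrs;         // attributes bound to Variables

  // Mirrors HasAttr(name, with_attr_var): with with_attr_var=false,
  // Variable-typed attributes are invisible and the lookup returns false.
  bool HasAttr(const std::string& name, bool with_attr_var) const {
    if (constant_attrs.count(name) > 0) return true;
    return with_attr_var && variable_attrs.count(name) > 0;
  }
};

int main() {
  FakeOpDesc constant_dropout;
  constant_dropout.constant_attrs["dropout_prob"] = 0.5f;

  FakeOpDesc variable_dropout;
  variable_dropout.variable_attrs.insert("dropout_prob");

  // The guard pattern this commit adds inside OpTeller::Tell():
  // only a constant attribute lets the op through to TensorRT.
  auto tellable = [](const FakeOpDesc& desc) {
    return desc.HasAttr("dropout_prob", /*with_attr_var=*/false);
  };

  std::cout << std::boolalpha
            << "constant dropout_prob -> " << tellable(constant_dropout) << "\n"   // true
            << "Variable dropout_prob -> " << tellable(variable_dropout) << "\n";  // false
  return 0;
}
```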
...@@ -303,9 +303,6 @@ bool OpTeller::Tell(const framework::ir::Node* node,
desc.HasAttr("skip_quant"))
return false;
// do not support Attribute with Variable(s) Type
if (HasUnsupportAttrVar(desc)) return false;
for (auto& teller : tellers_) {
std::unordered_set<std::string> act_op_list = {
"relu", "relu6", "sigmoid",
...@@ -364,7 +361,30 @@ bool OpTeller::Tell(const framework::ir::Node* node,
}
}
if (op_type == "dropout") {
/*
* Some OpDesc Attributes support both a constant value and a dynamic
* runtime value (i.e. a Variable type). TensorRT may only support
* constant-value Attributes, so we must detect this case here and
* return false from OpTeller::Tell().
* If the Attribute is a Variable, HasAttr() will return false.
*/
if (!desc.HasAttr("dropout_prob", /*with_attr_var=*/false)) {
VLOG(3)
<< "Skip to convert into TRT while found Attribute('dropout_prob') "
"is Variable type in dropout.";
return false;
}
}
if (op_type == "pool2d") {
// If the Attribute is a Variable, HasAttr() will return false
if (!desc.HasAttr("ksize", /*with_attr_var=*/false)) {
VLOG(3) << "Skip to convert into TRT while found Attribute('ksize') is "
"Variable type in pool2d.";
return false;
}
std::vector<int> paddings =
PADDLE_GET_CONST(std::vector<int>, desc.GetAttr("paddings"));
if (paddings.size() > 2) {
...@@ -797,6 +817,12 @@ bool OpTeller::Tell(const framework::ir::Node* node,
}
if (op_type == "arg_max") {
if (!desc.HasAttr("axis", /*with_attr_var=*/false)) {
VLOG(3) << "Skip to convert into TRT while found Attribute('axis') is "
"Variable type in arg_max.";
return false;
}
int axis = desc.HasAttr("axis")
? PADDLE_GET_CONST(int64_t, desc.GetAttr("axis"))
: -1;
...@@ -1061,6 +1087,13 @@ bool OpTeller::Tell(const framework::ir::Node* node,
}
if (op_type == "squeeze2") {
// If the Attribute is a Variable, HasAttr() will return false
if (!desc.HasAttr("axes", /*with_attr_var=*/false)) {
VLOG(3) << "Skip to convert into TRT while found Attribute('axes') is "
"Variable type in squeeze2.";
return false;
}
std::vector<int> axes;
if (desc.HasAttr("axes")) {
axes = PADDLE_GET_CONST(std::vector<int>, desc.GetAttr("axes"));
...@@ -2002,6 +2035,13 @@ bool OpTeller::Tell(const framework::ir::Node* node,
}
if (op_type == "reduce_sum" || op_type == "reduce_mean") {
if (!desc.HasAttr("dim", /*with_attr_var=*/false)) {
VLOG(3) << "Skip to convert into TRT while found Attribute('dim') is "
"Variable type in "
<< desc.Type();
return false;
}
if (!(desc.HasAttr("keep_dim") && desc.HasAttr("dim") &&
desc.HasAttr("reduce_all"))) {
VLOG(3) << "the " << op_type
...@@ -2265,34 +2305,6 @@ bool OpTeller::Tell(const framework::ir::Node* node,
return false;
}
bool OpTeller::HasUnsupportAttrVar(const framework::OpDesc& desc) const {
const std::string op_type = desc.Type();
auto has_attr_var = [&](const std::string& attr_name) -> bool {
// If Attribute is Variable(s), HasAttr() will return False
return !desc.HasAttr(attr_name, /*with_attr_var=*/false);
};
std::unordered_map<std::string, std::vector<std::string>> attrs_info = {
{"dropout", {"dropout_prob"}},
{"pool2d", {"ksize"}},
{"arg_max", {"axis"}},
{"reduce_mean", {"dim"}},
{"reduce_sum", {"dim"}},
{"squeeze2", {"axes"}},
};
bool flag = false;
auto iter = attrs_info.find(op_type);
if (iter != attrs_info.end()) {
for (auto& attr_name : iter->second) {
if (has_attr_var(attr_name)) {
flag = true;
break;
}
}
}
return flag;
}
OpTeller::OpTeller() { tellers_.emplace_back(new SimpleOpTypeSetTeller); }
} // namespace tensorrt
} // namespace inference
...
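The block comment in the dropout branch above explains why the guard must run before any constant is read: once `HasAttr(name, /*with_attr_var=*/false)` confirms the attribute is a constant, the branch can fetch it (via `PADDLE_GET_CONST` in the real code) or fall back to a default. Below is a small hypothetical mock of that guard-then-read flow for the arg_max branch; `FakeArgMaxDesc` and `TellArgMax` are illustrative stand-ins under those assumptions, not Paddle APIs.

```cpp
#include <iostream>
#include <map>
#include <string>

// Hypothetical mock, not the real Paddle API: models the "guard first,
// then read the constant" flow of the arg_max branch above.
struct FakeArgMaxDesc {
  std::map<std::string, long long> constant_attrs;  // constant int64 attributes
  bool axis_is_variable = false;                    // is "axis" bound to a Variable?

  bool HasAttr(const std::string& name, bool with_attr_var) const {
    if (constant_attrs.count(name) > 0) return true;
    return with_attr_var && name == "axis" && axis_is_variable;
  }
};

// Mirrors the early return added in OpTeller::Tell(): reject conversion
// when "axis" is not available as a constant.
bool TellArgMax(const FakeArgMaxDesc& desc) {
  if (!desc.HasAttr("axis", /*with_attr_var=*/false)) {
    std::cout << "skip: axis is a Variable (or missing) in arg_max\n";
    return false;
  }
  long long axis = desc.constant_attrs.count("axis")
                       ? desc.constant_attrs.at("axis")  // stand-in for PADDLE_GET_CONST
                       : -1;
  std::cout << "convertible, axis = " << axis << "\n";
  return true;
}

int main() {
  FakeArgMaxDesc constant_axis;
  constant_axis.constant_attrs["axis"] = 1;

  FakeArgMaxDesc variable_axis;
  variable_axis.axis_is_variable = true;

  TellArgMax(constant_axis);  // convertible, axis = 1
  TellArgMax(variable_axis);  // skipped
  return 0;
}
```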
...@@ -73,14 +73,6 @@ class OpTeller {
private:
OpTeller();
/*
* Some OpDescs Attribute support both constant value and dynamic
* runtime value (which is a Variable(s) type). But TensorRT maybe
* only support constant value Attribute, so we shall distinguish
* this case in time and return False in OpTeller.Tell().
*/
bool HasUnsupportAttrVar(const framework::OpDesc& desc) const;
private:
std::vector<std::unique_ptr<Teller>> tellers_;
};
...