未验证 提交 a642365e 编写于 作者: A Aurelius84 提交者: GitHub

[OpAttr]Refine Teller logic if encounter OpDesc with Variable type Attribute (#45795)

* [OpAttr]Refine Teller logic if encounter OpDesc with Variable type Attribute

* fix iterator

* fix typo

* fix lambda expr

* fix ptr
上级 bd4ce23e
...@@ -799,7 +799,11 @@ Attribute OpDesc::GetAttr(const std::string &name, bool with_attr_var) const { ...@@ -799,7 +799,11 @@ Attribute OpDesc::GetAttr(const std::string &name, bool with_attr_var) const {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
HasAttrVar(it->second), HasAttrVar(it->second),
false, false,
platform::errors::NotFound("Attribute %s is not found.", name)); platform::errors::NotFound(
"Attribute %s with constant value is not found, but found it with "
"Variable(s) type, which maybe not supported in some scenarios "
"currently, such as TensorRT et.al",
name));
} }
return it->second; return it->second;
} }
......
...@@ -303,6 +303,9 @@ bool OpTeller::Tell(const framework::ir::Node* node, ...@@ -303,6 +303,9 @@ bool OpTeller::Tell(const framework::ir::Node* node,
desc.HasAttr("skip_quant")) desc.HasAttr("skip_quant"))
return false; return false;
// do not support Attribute with Variable(s) Type
if (HasUnsupportAttrVar(desc)) return false;
for (auto& teller : tellers_) { for (auto& teller : tellers_) {
std::unordered_set<std::string> act_op_list = { std::unordered_set<std::string> act_op_list = {
"relu", "relu6", "sigmoid", "relu", "relu6", "sigmoid",
...@@ -2261,6 +2264,35 @@ bool OpTeller::Tell(const framework::ir::Node* node, ...@@ -2261,6 +2264,35 @@ bool OpTeller::Tell(const framework::ir::Node* node,
return false; return false;
} }
bool OpTeller::HasUnsupportAttrVar(const framework::OpDesc& desc) const {
const std::string op_type = desc.Type();
auto has_attr_var = [&](const std::string& attr_name) -> bool {
// If Attribute is Variable(s), HasAttr() will return False
return !desc.HasAttr(attr_name, /*with_attr_var=*/false);
};
std::unordered_map<std::string, std::vector<std::string>> attrs_info = {
{"dropout", {"dropout_prob"}},
{"pool2d", {"ksize"}},
{"arg_max", {"axis"}},
{"reduce_mean", {"dim"}},
{"reduce_sum", {"dim"}},
{"squeeze2", {"axes"}},
};
bool flag = false;
auto iter = attrs_info.find(op_type);
if (iter != attrs_info.end()) {
for (auto& attr_name : iter->second) {
if (has_attr_var(attr_name)) {
flag = true;
break;
}
}
}
return flag;
}
OpTeller::OpTeller() { tellers_.emplace_back(new SimpleOpTypeSetTeller); } OpTeller::OpTeller() { tellers_.emplace_back(new SimpleOpTypeSetTeller); }
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
......
...@@ -73,6 +73,14 @@ class OpTeller { ...@@ -73,6 +73,14 @@ class OpTeller {
private: private:
OpTeller(); OpTeller();
/*
* Some OpDescs Attribute support both constant value and dynamic
* runtime value (which is a Variable(s) type). But TensorRT maybe
* only support constant value Attribute, so we shall distinguish
* this case in time and return False in OpTeller.Tell().
*/
bool HasUnsupportAttrVar(const framework::OpDesc& desc) const;
private: private:
std::vector<std::unique_ptr<Teller>> tellers_; std::vector<std::unique_ptr<Teller>> tellers_;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册