未验证 提交 a642365e 编写于 作者: A Aurelius84 提交者: GitHub

[OpAttr]Refine Teller logic if encounter OpDesc with Variable type Attribute (#45795)

* [OpAttr]Refine Teller logic if encounter OpDesc with Variable type Attribute

* fix iterator

* fix typo

* fix lambda expr

* fix ptr
上级 bd4ce23e
......@@ -799,7 +799,11 @@ Attribute OpDesc::GetAttr(const std::string &name, bool with_attr_var) const {
PADDLE_ENFORCE_EQ(
HasAttrVar(it->second),
false,
platform::errors::NotFound("Attribute %s is not found.", name));
platform::errors::NotFound(
"Attribute %s with constant value is not found, but found it with "
"Variable(s) type, which maybe not supported in some scenarios "
"currently, such as TensorRT et.al",
name));
}
return it->second;
}
......
......@@ -303,6 +303,9 @@ bool OpTeller::Tell(const framework::ir::Node* node,
desc.HasAttr("skip_quant"))
return false;
// do not support Attribute with Variable(s) Type
if (HasUnsupportAttrVar(desc)) return false;
for (auto& teller : tellers_) {
std::unordered_set<std::string> act_op_list = {
"relu", "relu6", "sigmoid",
......@@ -2261,6 +2264,35 @@ bool OpTeller::Tell(const framework::ir::Node* node,
return false;
}
bool OpTeller::HasUnsupportAttrVar(const framework::OpDesc& desc) const {
const std::string op_type = desc.Type();
auto has_attr_var = [&](const std::string& attr_name) -> bool {
// If Attribute is Variable(s), HasAttr() will return False
return !desc.HasAttr(attr_name, /*with_attr_var=*/false);
};
std::unordered_map<std::string, std::vector<std::string>> attrs_info = {
{"dropout", {"dropout_prob"}},
{"pool2d", {"ksize"}},
{"arg_max", {"axis"}},
{"reduce_mean", {"dim"}},
{"reduce_sum", {"dim"}},
{"squeeze2", {"axes"}},
};
bool flag = false;
auto iter = attrs_info.find(op_type);
if (iter != attrs_info.end()) {
for (auto& attr_name : iter->second) {
if (has_attr_var(attr_name)) {
flag = true;
break;
}
}
}
return flag;
}
OpTeller::OpTeller() { tellers_.emplace_back(new SimpleOpTypeSetTeller); }
} // namespace tensorrt
} // namespace inference
......
......@@ -73,6 +73,14 @@ class OpTeller {
private:
OpTeller();
/*
* Some OpDescs Attribute support both constant value and dynamic
* runtime value (which is a Variable(s) type). But TensorRT maybe
* only support constant value Attribute, so we shall distinguish
* this case in time and return False in OpTeller.Tell().
*/
bool HasUnsupportAttrVar(const framework::OpDesc& desc) const;
private:
std::vector<std::unique_ptr<Teller>> tellers_;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册