From a642365e8c7c121d44fc0c162d7bd0aa2663ff01 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 8 Sep 2022 14:00:57 +0800 Subject: [PATCH] [OpAttr]Refine Teller logic if encounter OpDesc with Variable type Attribute (#45795) * [OpAttr]Refine Teller logic if encounter OpDesc with Variable type Attribute * fix iterator * fix typo * fix lambda expr * fix ptr --- paddle/fluid/framework/op_desc.cc | 6 +++- paddle/fluid/inference/tensorrt/op_teller.cc | 32 ++++++++++++++++++++ paddle/fluid/inference/tensorrt/op_teller.h | 8 +++++ 3 files changed, 45 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index fca4ff253d..4d0d10c783 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -799,7 +799,11 @@ Attribute OpDesc::GetAttr(const std::string &name, bool with_attr_var) const { PADDLE_ENFORCE_EQ( HasAttrVar(it->second), false, - platform::errors::NotFound("Attribute %s is not found.", name)); + platform::errors::NotFound( + "Attribute %s with constant value is not found, but found it with " + "Variable(s) type, which maybe not supported in some scenarios " + "currently, such as TensorRT et.al", + name)); } return it->second; } diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 6286010a03..040a3c9a22 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -303,6 +303,9 @@ bool OpTeller::Tell(const framework::ir::Node* node, desc.HasAttr("skip_quant")) return false; + // do not support Attribute with Variable(s) Type + if (HasUnsupportAttrVar(desc)) return false; + for (auto& teller : tellers_) { std::unordered_set act_op_list = { "relu", "relu6", "sigmoid", @@ -2261,6 +2264,35 @@ bool OpTeller::Tell(const framework::ir::Node* node, return false; } + +bool OpTeller::HasUnsupportAttrVar(const framework::OpDesc& desc) const { + const std::string op_type = desc.Type(); + auto has_attr_var = [&](const std::string& attr_name) -> bool { + // If Attribute is Variable(s), HasAttr() will return False + return !desc.HasAttr(attr_name, /*with_attr_var=*/false); + }; + std::unordered_map> attrs_info = { + {"dropout", {"dropout_prob"}}, + {"pool2d", {"ksize"}}, + {"arg_max", {"axis"}}, + {"reduce_mean", {"dim"}}, + {"reduce_sum", {"dim"}}, + {"squeeze2", {"axes"}}, + }; + + bool flag = false; + auto iter = attrs_info.find(op_type); + if (iter != attrs_info.end()) { + for (auto& attr_name : iter->second) { + if (has_attr_var(attr_name)) { + flag = true; + break; + } + } + } + return flag; +} + OpTeller::OpTeller() { tellers_.emplace_back(new SimpleOpTypeSetTeller); } } // namespace tensorrt } // namespace inference diff --git a/paddle/fluid/inference/tensorrt/op_teller.h b/paddle/fluid/inference/tensorrt/op_teller.h index 1a6ce092a1..d7d7d8b674 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.h +++ b/paddle/fluid/inference/tensorrt/op_teller.h @@ -73,6 +73,14 @@ class OpTeller { private: OpTeller(); + /* + * Some OpDescs Attribute support both constant value and dynamic + * runtime value (which is a Variable(s) type). But TensorRT maybe + * only support constant value Attribute, so we shall distinguish + * this case in time and return False in OpTeller.Tell(). + */ + bool HasUnsupportAttrVar(const framework::OpDesc& desc) const; + private: std::vector> tellers_; }; -- GitLab