diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 040a3c9a22af85ecebcdef6862d098bd26cebbf4..957241cb3e89d6fe8c4ab9b083664abd1ca97ed4 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -303,9 +303,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, desc.HasAttr("skip_quant")) return false; - // do not support Attribute with Variable(s) Type - if (HasUnsupportAttrVar(desc)) return false; - for (auto& teller : tellers_) { std::unordered_set act_op_list = { "relu", "relu6", "sigmoid", @@ -364,7 +361,30 @@ bool OpTeller::Tell(const framework::ir::Node* node, } } + if (op_type == "dropout") { + /* + * Some OpDescs Attribute support both constant value and dynamic + * runtime value (which is a Variable(s) type). But TensorRT maybe + * only support constant value Attribute, so we shall distinguish + * this case in time and return False in OpTeller.Tell(). + * If Attribute is Variable(s), HasAttr() will return False + */ + if (!desc.HasAttr("dropout_prob", /*with_attr_var=*/false)) { + VLOG(3) + << "Skip to convert into TRT while found Attribute('dropout_prob') " + "is Variable type in dropout."; + return false; + } + } + if (op_type == "pool2d") { + // If Attribute is Variable(s), HasAttr() will return False + if (!desc.HasAttr("ksize", /*with_attr_var=*/false)) { + VLOG(3) << "Skip to convert into TRT while found Attribute('ksize') is " + "Variable type in pool2d."; + return false; + } + std::vector paddings = PADDLE_GET_CONST(std::vector, desc.GetAttr("paddings")); if (paddings.size() > 2) { @@ -797,6 +817,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, } if (op_type == "arg_max") { + if (!desc.HasAttr("axis", /*with_attr_var=*/false)) { + VLOG(3) << "Skip to convert into TRT while found Attribute('axis') is " + "Variable type in arg_max."; + return false; + } + int axis = desc.HasAttr("axis") ? PADDLE_GET_CONST(int64_t, desc.GetAttr("axis")) : -1; @@ -1061,6 +1087,13 @@ bool OpTeller::Tell(const framework::ir::Node* node, } if (op_type == "squeeze2") { + // If Attribute is Variable(s), HasAttr() will return False + if (!desc.HasAttr("axes", /*with_attr_var=*/false)) { + VLOG(3) << "Skip to convert into TRT while found Attribute('axes') is " + "Variable type in squeeze2."; + return false; + } + std::vector axes; if (desc.HasAttr("axes")) { axes = PADDLE_GET_CONST(std::vector, desc.GetAttr("axes")); @@ -2002,6 +2035,13 @@ bool OpTeller::Tell(const framework::ir::Node* node, } if (op_type == "reduce_sum" || op_type == "reduce_mean") { + if (!desc.HasAttr("dim", /*with_attr_var=*/false)) { + VLOG(3) << "Skip to convert into TRT while found Attribute('dim') is " + "Variable type in " + << desc.Type(); + return false; + } + if (!(desc.HasAttr("keep_dim") && desc.HasAttr("dim") && desc.HasAttr("reduce_all"))) { VLOG(3) << "the " << op_type @@ -2265,34 +2305,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, return false; } -bool OpTeller::HasUnsupportAttrVar(const framework::OpDesc& desc) const { - const std::string op_type = desc.Type(); - auto has_attr_var = [&](const std::string& attr_name) -> bool { - // If Attribute is Variable(s), HasAttr() will return False - return !desc.HasAttr(attr_name, /*with_attr_var=*/false); - }; - std::unordered_map> attrs_info = { - {"dropout", {"dropout_prob"}}, - {"pool2d", {"ksize"}}, - {"arg_max", {"axis"}}, - {"reduce_mean", {"dim"}}, - {"reduce_sum", {"dim"}}, - {"squeeze2", {"axes"}}, - }; - - bool flag = false; - auto iter = attrs_info.find(op_type); - if (iter != attrs_info.end()) { - for (auto& attr_name : iter->second) { - if (has_attr_var(attr_name)) { - flag = true; - break; - } - } - } - return flag; -} - OpTeller::OpTeller() { tellers_.emplace_back(new SimpleOpTypeSetTeller); } } // namespace tensorrt } // namespace inference diff --git a/paddle/fluid/inference/tensorrt/op_teller.h b/paddle/fluid/inference/tensorrt/op_teller.h index d7d7d8b674001eabca2ffc974c92655300731dec..1a6ce092a18b43588f9db8372394155e596c73ab 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.h +++ b/paddle/fluid/inference/tensorrt/op_teller.h @@ -73,14 +73,6 @@ class OpTeller { private: OpTeller(); - /* - * Some OpDescs Attribute support both constant value and dynamic - * runtime value (which is a Variable(s) type). But TensorRT maybe - * only support constant value Attribute, so we shall distinguish - * this case in time and return False in OpTeller.Tell(). - */ - bool HasUnsupportAttrVar(const framework::OpDesc& desc) const; - private: std::vector> tellers_; };