diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 826d607c4ca65e0132e08d3eefcdbdbbaac88980..802a13ffd779299decfc3d622026c1b17c262a43 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -86,6 +86,49 @@ struct SimpleOpTypeSetTeller : public Teller { bool use_no_calib_int8 = false, bool with_dynamic_shape = false) override { const std::string op_type = desc.Type(); + + std::unordered_set control_set = {"conditional_block", + "while"}; + std::unordered_set feed_fetch_set = {"feed", "fetch"}; + if (control_set.find(op_type) != control_set.end()) { + return false; + } + + if (feed_fetch_set.find(op_type) != feed_fetch_set.end()) { + return false; + } + + // Dont.t allow fp64! + { + auto inputs = desc.Inputs(); + for (auto iter : inputs) { + for (auto var_name : iter.second) { + auto* block = desc.Block(); + if (block) { + auto* var_desc = block->FindVar(var_name); + auto dtype = var_desc->GetDataType(); + if (dtype == framework::proto::VarType::FP64) { + return false; + } + } + } + } + + auto outputs = desc.Outputs(); + for (auto iter : outputs) { + for (auto var_name : iter.second) { + auto* block = desc.Block(); + if (block) { + auto* var_desc = block->FindVar(var_name); + auto dtype = var_desc->GetDataType(); + if (dtype == framework::proto::VarType::FP64) { + return false; + } + } + } + } + } + // do not support the op which is labeled the `skip_quant` if ((desc.HasAttr("namescope") && PADDLE_GET_CONST(std::string, desc.GetAttr("op_namescope")) ==