diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index ad9b6156d9caf4a6edc31c3ace7eba78bdd1940f..e5e344e16cbb34379945d3e45fff64deda3800b8 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -691,7 +691,7 @@ struct SimpleOpTypeSetTeller : public Teller { : -1; bool flatten = PADDLE_GET_CONST(bool, desc.GetAttr("flatten")); int dtype = PADDLE_GET_CONST(int, desc.GetAttr("dtype")); - if (axis == 0 || flatten || dtype != 2) return false; + if (axis == 0 || flatten || (dtype != 2 && dtype != 3)) return false; } if (op_type == "arg_min") {