From fd1730e44367a2f0c18a46944a2d6640190194a9 Mon Sep 17 00:00:00 2001 From: Zhang Jun Date: Fri, 23 Dec 2022 19:06:23 +0800 Subject: [PATCH] [inference][trt] argmax support int64 dtype (#49167) --- paddle/fluid/inference/tensorrt/op_teller.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index ad9b6156d9..e5e344e16c 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -691,7 +691,7 @@ struct SimpleOpTypeSetTeller : public Teller { : -1; bool flatten = PADDLE_GET_CONST(bool, desc.GetAttr("flatten")); int dtype = PADDLE_GET_CONST(int, desc.GetAttr("dtype")); - if (axis == 0 || flatten || dtype != 2) return false; + if (axis == 0 || flatten || (dtype != 2 && dtype != 3)) return false; } if (op_type == "arg_min") { -- GitLab