From 2c12abd7fe3f3329b920afbb222df7fad787f5b6 Mon Sep 17 00:00:00 2001 From: Wilber Date: Thu, 27 Apr 2023 10:56:43 +0800 Subject: [PATCH] revert pr https://github.com/PaddlePaddle/Paddle/pull/46779 (#53373) --- paddle/fluid/inference/tensorrt/op_teller.cc | 10 ---------- test/ir/inference/test_trt_convert_gather.py | 2 +- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 78e300a8d73..bb0fbdf6ca8 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -542,16 +542,6 @@ struct SimpleOpTypeSetTeller : public Teller { "the pass."; return false; } - - auto index_var_name = desc.Input("Index")[0]; - auto* index_var_desc = block->FindVar(index_var_name); - - // The index input must be int32 datatype. - if (index_var_desc->GetDataType() != - paddle::framework::proto::VarType_Type::VarType_Type_INT32) { - VLOG(3) << "gather op Index input data type must be int32"; - return false; - } #if !IS_TRT_VERSION_GE(7000) auto* x_var_desc = block->FindVar(desc.Input("X")[0]); const auto x_shape = x_var_desc->GetShape(); diff --git a/test/ir/inference/test_trt_convert_gather.py b/test/ir/inference/test_trt_convert_gather.py index 3c25dd6eff1..69a2624b77e 100644 --- a/test/ir/inference/test_trt_convert_gather.py +++ b/test/ir/inference/test_trt_convert_gather.py @@ -182,7 +182,7 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest): if self.input_num == 3: return 0, 5 else: - if dynamic_shape and self.index_type_int32: + if dynamic_shape: return 1, 3 else: return 0, 4 -- GitLab