diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 78e300a8d730d58357b13dc8b0f54fd772086452..bb0fbdf6ca84821a2f7689656465f137d8f8c989 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -542,16 +542,6 @@ struct SimpleOpTypeSetTeller : public Teller { "the pass."; return false; } - - auto index_var_name = desc.Input("Index")[0]; - auto* index_var_desc = block->FindVar(index_var_name); - - // The index input must be int32 datatype. - if (index_var_desc->GetDataType() != - paddle::framework::proto::VarType_Type::VarType_Type_INT32) { - VLOG(3) << "gather op Index input data type must be int32"; - return false; - } #if !IS_TRT_VERSION_GE(7000) auto* x_var_desc = block->FindVar(desc.Input("X")[0]); const auto x_shape = x_var_desc->GetShape(); diff --git a/test/ir/inference/test_trt_convert_gather.py b/test/ir/inference/test_trt_convert_gather.py index 3c25dd6eff1c9ae21047ff5061960d5a5d8ef30c..69a2624b77e092eb116ca1fd913329dde74ae8bf 100644 --- a/test/ir/inference/test_trt_convert_gather.py +++ b/test/ir/inference/test_trt_convert_gather.py @@ -182,7 +182,7 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest): if self.input_num == 3: return 0, 5 else: - if dynamic_shape and self.index_type_int32: + if dynamic_shape: return 1, 3 else: return 0, 4