From a0e03418d821d9d6ca9ad12422d6a0dfdd602d56 Mon Sep 17 00:00:00 2001 From: feng_shuai Date: Mon, 10 Oct 2022 17:45:41 +0800 Subject: [PATCH] Fix gather op convert for Paddle-TensorRT (#46779) (#46825) * fix gather op convert to only support int32 index as input. * add ut --- paddle/fluid/inference/tensorrt/op_teller.cc | 10 +++ .../ir/inference/test_trt_convert_gather.py | 85 ++++++++++--------- 2 files changed, 56 insertions(+), 39 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 5756efefa5b..f4c692d1251 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -534,6 +534,16 @@ struct SimpleOpTypeSetTeller : public Teller { "the pass."; return false; } + + auto index_var_name = desc.Input("Index")[0]; + auto* index_var_desc = block->FindVar(index_var_name); + + // The index input must be int32 datatype. + if (index_var_desc->GetDataType() != + paddle::framework::proto::VarType_Type::VarType_Type_INT32) { + VLOG(3) << "gather op Index input data type must be int32"; + return false; + } #if !IS_TRT_VERSION_GE(7000) auto* x_var_desc = block->FindVar(desc.Input("X")[0]); const auto x_shape = x_var_desc->GetShape(); diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_gather.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_gather.py index 5405f114651..25d0d48c8c3 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_gather.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_gather.py @@ -42,6 +42,9 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest): def generate_input2(index): return np.array(index).astype(np.int32) + def generate_input4(index): + return np.array(index).astype(np.int64) + def generate_input3(axis): return np.array([axis]).astype(np.int32) @@ -57,44 +60,48 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest): "Index": ["index_data"], "Axis": ["axis_data"] }]: - self.shape = shape - self.axis = axis - self.input_num = len(input) - dics = [{"overwrite": overwrite, "axis": axis}] - ops_config = [{ - "op_type": "gather", - "op_inputs": input, - "op_outputs": { - "Out": ["output_data"] - }, - "op_attrs": dics[0] - }] - ops = self.generate_op_config(ops_config) - - program_config = ProgramConfig( - ops=ops, - weights={}, - inputs={ - "input_data": - TensorConfig(data_gen=partial( - generate_input1, shape)), - "index_data": - TensorConfig(data_gen=partial( - generate_input2, index)), - } if len(input) == 2 else { - "input_data": - TensorConfig(data_gen=partial( - generate_input1, shape)), - "index_data": - TensorConfig(data_gen=partial( - generate_input2, index)), - "axis_data": - TensorConfig(data_gen=partial( - generate_input3, axis)), - }, - outputs=["output_data"]) - - yield program_config + for index_type_int32 in [True, False]: + self.shape = shape + self.axis = axis + self.input_num = len(input) + self.index_type_int32 = index_type_int32 + dics = [{"overwrite": overwrite, "axis": axis}] + ops_config = [{ + "op_type": "gather", + "op_inputs": input, + "op_outputs": { + "Out": ["output_data"] + }, + "op_attrs": dics[0] + }] + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "input_data": + TensorConfig(data_gen=partial( + generate_input1, shape)), + "index_data": + TensorConfig(data_gen=partial( + generate_input2 + if index_type_int32 == + True else generate_input4, index)), + } if len(input) == 2 else { + "input_data": + TensorConfig(data_gen=partial( + generate_input1, shape)), + "index_data": + TensorConfig(data_gen=partial( + generate_input2, index)), + "axis_data": + TensorConfig(data_gen=partial( + generate_input3, axis)), + }, + outputs=["output_data"]) + + yield program_config def sample_predictor_configs( self, program_config) -> (paddle_infer.Config, List[int], float): @@ -162,7 +169,7 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest): if self.input_num == 3: return 0, 5 else: - if dynamic_shape: + if dynamic_shape and self.index_type_int32 == True: return 1, 3 else: return 0, 4 -- GitLab