diff --git a/test/ir/inference/test_trt_convert_reshape.py b/test/ir/inference/test_trt_convert_reshape.py index 3f88b39003bb9dea6cf468b88d9cb1637d4d2b6b..c30d973651bad125911aad42dc0bc43749d32c03 100644 --- a/test/ir/inference/test_trt_convert_reshape.py +++ b/test/ir/inference/test_trt_convert_reshape.py @@ -431,5 +431,99 @@ class TrtConvertReshapeTest3(TrtLayerAutoScanTest): self.run_test() +class TrtConvertReshapeZeroDimsTest(TrtLayerAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + return True + + def sample_program_configs(self): + def generate_input1(attrs: List[Dict[str, Any]]): + if self.dims > 0: + self.input_shape = [1] * self.dims + return np.random.random(self.input_shape).astype(np.float32) + elif self.dims == 0: + self.input_shape = [] + return np.random.random([]).astype(np.float32) + + for dims in [0, 1, 2, 3]: + for shape in [ + [], + [1, 1], + ]: + dics = [ + { + "shape": shape, + }, + ] + self.dims = dims + dics_intput = [{"X": ["reshape_input"]}] + + ops_config = [ + { + "op_type": "reshape", + "op_inputs": dics_intput[0], + "op_outputs": {"Out": ["reshape_out"]}, + "op_attrs": dics[0], + } + ] + ops = self.generate_op_config(ops_config) + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "reshape_input": TensorConfig( + data_gen=partial(generate_input1, dics) + ) + }, + outputs=["reshape_out"], + ) + + yield program_config + + def sample_predictor_configs( + self, program_config + ) -> (paddle_infer.Config, List[int], float): + def generate_dynamic_shape(attrs): + self.dynamic_shape.min_input_shape = { + "reshape_input": self.input_shape + } + self.dynamic_shape.max_input_shape = { + "reshape_input": self.input_shape + } + self.dynamic_shape.opt_input_shape = { + "reshape_input": self.input_shape + } + + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + # only test dynamic shape mode + return 1, 2 + + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True + ), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True + ), 1e-3 + + def add_skip_trt_case(self): + pass + + def test(self): + self.add_skip_trt_case() + self.run_test() + + if __name__ == "__main__": unittest.main()