diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_roi_align.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_roi_align.py index eafd8debc27e79115ce87167546b6234df8644ba..1e1a83a40e48ac4be8b2dd5c280db4ec4660f973 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_roi_align.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_roi_align.py @@ -35,7 +35,20 @@ class TrtConvertRoiAlignTest(TrtLayerAutoScanTest): return np.random.random([3, 4]).astype(np.float32) def generate_input3(attrs: List[Dict[str, Any]], batch): - return np.random.random([batch]).astype(np.int32) + if batch == 1: + return np.array([3]).astype(np.int32) + if batch == 2: + return np.array([1, 2]).astype(np.int32) + if batch == 4: + return np.array([1, 1, 0, 1]).astype(np.int32) + + def generate_lod(batch): + if batch == 1: + return [[0, 3]] + if batch == 2: + return [[0, 1, 3]] + if batch == 4: + return [[0, 1, 2, 2, 3]] for num_input in [0, 1]: for batch in [1, 2, 4]: @@ -96,7 +109,7 @@ class TrtConvertRoiAlignTest(TrtLayerAutoScanTest): data_gen=partial( generate_input2, dics, batch ), - lod=[[32, 3]], + lod=generate_lod(batch), ), }, ]