未验证 提交 3cbca60f 编写于 作者: W wangguanzhong 提交者: GitHub

fix trt roi_align test (#48570)

上级 e5cf75d8
......@@ -35,7 +35,20 @@ class TrtConvertRoiAlignTest(TrtLayerAutoScanTest):
return np.random.random([3, 4]).astype(np.float32)
def generate_input3(attrs: List[Dict[str, Any]], batch):
return np.random.random([batch]).astype(np.int32)
if batch == 1:
return np.array([3]).astype(np.int32)
if batch == 2:
return np.array([1, 2]).astype(np.int32)
if batch == 4:
return np.array([1, 1, 0, 1]).astype(np.int32)
def generate_lod(batch):
if batch == 1:
return [[0, 3]]
if batch == 2:
return [[0, 1, 3]]
if batch == 4:
return [[0, 1, 2, 2, 3]]
for num_input in [0, 1]:
for batch in [1, 2, 4]:
......@@ -96,7 +109,7 @@ class TrtConvertRoiAlignTest(TrtLayerAutoScanTest):
data_gen=partial(
generate_input2, dics, batch
),
lod=[[32, 3]],
lod=generate_lod(batch),
),
},
]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册