未验证 提交 3cbca60f 编写于 作者: W wangguanzhong 提交者: GitHub

fix trt roi_align test (#48570)

上级 e5cf75d8
...@@ -35,7 +35,20 @@ class TrtConvertRoiAlignTest(TrtLayerAutoScanTest): ...@@ -35,7 +35,20 @@ class TrtConvertRoiAlignTest(TrtLayerAutoScanTest):
return np.random.random([3, 4]).astype(np.float32) return np.random.random([3, 4]).astype(np.float32)
def generate_input3(attrs: List[Dict[str, Any]], batch): def generate_input3(attrs: List[Dict[str, Any]], batch):
return np.random.random([batch]).astype(np.int32) if batch == 1:
return np.array([3]).astype(np.int32)
if batch == 2:
return np.array([1, 2]).astype(np.int32)
if batch == 4:
return np.array([1, 1, 0, 1]).astype(np.int32)
def generate_lod(batch):
if batch == 1:
return [[0, 3]]
if batch == 2:
return [[0, 1, 3]]
if batch == 4:
return [[0, 1, 2, 2, 3]]
for num_input in [0, 1]: for num_input in [0, 1]:
for batch in [1, 2, 4]: for batch in [1, 2, 4]:
...@@ -96,7 +109,7 @@ class TrtConvertRoiAlignTest(TrtLayerAutoScanTest): ...@@ -96,7 +109,7 @@ class TrtConvertRoiAlignTest(TrtLayerAutoScanTest):
data_gen=partial( data_gen=partial(
generate_input2, dics, batch generate_input2, dics, batch
), ),
lod=[[32, 3]], lod=generate_lod(batch),
), ),
}, },
] ]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册