未验证 提交 0c66065f 编写于 作者: F Frank Lin 提交者: GitHub

fix test trt convert nms (#56483)

上级 0012c8d5
...@@ -70,10 +70,10 @@ class TrtConvertMulticlassNMSTest(TrtLayerAutoScanTest): ...@@ -70,10 +70,10 @@ class TrtConvertMulticlassNMSTest(TrtLayerAutoScanTest):
) )
def generate_scores(batch, num_boxes, num_classes): def generate_scores(batch, num_boxes, num_classes):
return np.arange( max_value = batch * num_classes * num_boxes
batch * num_classes * num_boxes, dtype=np.float32 return (1 / max_value) * np.arange(
max_value, dtype=np.float32
).reshape([batch, num_classes, num_boxes]) ).reshape([batch, num_classes, num_boxes])
# return np.random.rand(batch, num_classes, num_boxes).astype(np.float32)
for batch in [1, 2]: for batch in [1, 2]:
self.batch = batch self.batch = batch
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册