From 0c66065f2c8b34f9b6c2c2e5fbb406a13a5f4500 Mon Sep 17 00:00:00 2001 From: Frank Lin Date: Fri, 25 Aug 2023 11:18:48 +0800 Subject: [PATCH] fix test trt convert nms (#56483) --- test/ir/inference/test_trt_convert_multiclass_nms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/ir/inference/test_trt_convert_multiclass_nms.py b/test/ir/inference/test_trt_convert_multiclass_nms.py index 3e40c669935..0033bf8aa4b 100644 --- a/test/ir/inference/test_trt_convert_multiclass_nms.py +++ b/test/ir/inference/test_trt_convert_multiclass_nms.py @@ -70,10 +70,10 @@ class TrtConvertMulticlassNMSTest(TrtLayerAutoScanTest): ) def generate_scores(batch, num_boxes, num_classes): - return np.arange( - batch * num_classes * num_boxes, dtype=np.float32 + max_value = batch * num_classes * num_boxes + return (1 / max_value) * np.arange( + max_value, dtype=np.float32 ).reshape([batch, num_classes, num_boxes]) - # return np.random.rand(batch, num_classes, num_boxes).astype(np.float32) for batch in [1, 2]: self.batch = batch -- GitLab