diff --git a/test/ir/inference/test_trt_convert_multiclass_nms.py b/test/ir/inference/test_trt_convert_multiclass_nms.py index 3e40c66993553d827f0e136f40b30f42c3e23ec2..0033bf8aa4bdd98ec7964738de071d0d851eb77c 100644 --- a/test/ir/inference/test_trt_convert_multiclass_nms.py +++ b/test/ir/inference/test_trt_convert_multiclass_nms.py @@ -70,10 +70,10 @@ class TrtConvertMulticlassNMSTest(TrtLayerAutoScanTest): ) def generate_scores(batch, num_boxes, num_classes): - return np.arange( - batch * num_classes * num_boxes, dtype=np.float32 + max_value = batch * num_classes * num_boxes + return (1 / max_value) * np.arange( + max_value, dtype=np.float32 ).reshape([batch, num_classes, num_boxes]) - # return np.random.rand(batch, num_classes, num_boxes).astype(np.float32) for batch in [1, 2]: self.batch = batch