From 8d236eea3c2594cbc4ad9a3574a45bfa762a5750 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 8 Dec 2020 18:16:12 -0800 Subject: [PATCH] Hybrid auto-labelling support (#1599) * Introduce hybrid auto-labelling support * cleanup --- test.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test.py b/test.py index 5b10e0f..5c8a70b 100644 --- a/test.py +++ b/test.py @@ -33,7 +33,8 @@ def test(data, dataloader=None, save_dir=Path(''), # for saving images save_txt=False, # for auto-labelling - save_conf=False, + save_hybrid=False, # for hybrid auto-labelling + save_conf=False, # save auto-label confidences plots=True, log_imgs=0): # number of logged images @@ -45,7 +46,6 @@ def test(data, else: # called directly set_logging() device = select_device(opt.device, batch_size=batch_size) - save_txt = opt.save_txt # save *.txt labels # Directories save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run @@ -115,7 +115,7 @@ def test(data, # Run NMS targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) # to pixels - lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_txt else [] # for autolabelling + lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling t = time_synchronized() output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, labels=lb) t1 += time_synchronized() - t @@ -292,6 +292,7 @@ if __name__ == '__main__': parser.add_argument('--augment', action='store_true', help='augmented inference') parser.add_argument('--verbose', action='store_true', help='report mAP by class') parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') + parser.add_argument('--save-hybrid', action='store_true', help='save label+prediction hybrid results to *.txt') parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels') parser.add_argument('--save-json', action='store_true', help='save a cocoapi-compatible JSON results file') parser.add_argument('--project', default='runs/test', help='save to project/name') @@ -313,7 +314,8 @@ if __name__ == '__main__': opt.single_cls, opt.augment, opt.verbose, - save_txt=opt.save_txt, + save_txt=opt.save_txt | opt.save_hybrid, + save_hybrid=opt.save_hybrid, save_conf=opt.save_conf, ) -- GitLab