test.py 15.7 KB
Newer Older
G
Glenn Jocher 已提交
1
import argparse
G
updates  
Glenn Jocher 已提交
2
import json
3 4
import os
from pathlib import Path
G
Glenn Jocher 已提交
5
from threading import Thread
G
Glenn Jocher 已提交
6

7 8 9 10
import numpy as np
import torch
import yaml
from tqdm import tqdm
G
updates  
Glenn Jocher 已提交
11

12 13 14 15 16 17 18 19
from models.experimental import attempt_load
from utils.datasets import create_dataloader
from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, box_iou, \
    non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path
from utils.loss import compute_loss
from utils.metrics import ap_per_class, ConfusionMatrix
from utils.plots import plot_images, output_to_target, plot_study_txt
from utils.torch_utils import select_device, time_synchronized
G
Glenn Jocher 已提交
20

21

22
def test(data,
         weights=None,
         batch_size=32,
         imgsz=640,
         conf_thres=0.001,
         iou_thres=0.6,  # for NMS
         save_json=False,
         single_cls=False,
         augment=False,
         verbose=False,
         model=None,
         dataloader=None,
         save_dir=Path(''),  # for saving images
         save_txt=False,  # for auto-labelling
         save_hybrid=False,  # for hybrid auto-labelling
         save_conf=False,  # save auto-label confidences
         plots=True,
         log_imgs=0):  # number of logged images
    """Evaluate a detection model and return precision, recall and mAP metrics.

    Two calling modes are supported:

    * From train.py: ``model`` (and ``dataloader``) are passed in and the model's current
      device is reused.
    * Directly (``__main__``): the model is loaded from ``weights`` and the dataloader is built
      here. In this mode the function also reads ``device``, ``project``, ``name``,
      ``exist_ok`` and ``task`` from the module-level ``opt`` namespace created by the argument
      parser, so it cannot be called this way unless ``opt`` is defined.

    Args:
        data: path to the dataset yaml; it must provide ``nc`` and ``val`` (or ``test``) entries.
        weights: weights path(s) passed to ``attempt_load`` when ``model`` is None.
        batch_size: dataloader batch size (direct mode only).
        imgsz: inference image size in pixels (direct mode only; checked against model stride).
        conf_thres: NMS confidence threshold.
        iou_thres: NMS IoU threshold.
        save_json: write predictions as a pycocotools JSON file and evaluate it with COCOeval.
        single_cls: treat the dataset as a single class.
        augment: run augmented inference.
        verbose: print per-class metrics (only when there is more than one class).
        model: already-loaded model (train.py mode).
        dataloader: already-built dataloader (train.py mode).
        save_dir: output directory; replaced by an incremented run directory in direct mode.
        save_txt: save predictions as normalized ``class x y w h`` ``*.txt`` labels.
        save_hybrid: pass ground-truth labels to NMS (``labels=``) for hybrid auto-labelling.
        save_conf: append confidences to saved ``*.txt`` labels.
        plots: save label/prediction mosaics, the confusion matrix and ``ap_per_class`` plots.
        log_imgs: number of images (capped at 100) to log to Weights & Biases, if installed.

    Returns:
        tuple: ``((mp, mr, map50, map, box_loss, obj_loss, cls_loss), maps, t)`` where ``maps``
        is a per-class mAP@0.5:0.95 array of length ``nc`` (classes without results get the
        overall mAP) and ``t`` is ``(inference ms, NMS ms, total ms, imgsz, imgsz, batch_size)``
        per image. Losses are averaged over batches and only accumulate in train.py mode.
    """
    # Initialize/load model and set device
    training = model is not None
    if training:  # called by train.py
        device = next(model.parameters()).device  # get model device
    else:  # called directly; relies on the module-level `opt` from __main__
        set_logging()
        device = select_device(opt.device, batch_size=batch_size)

        # Directories
        save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))  # increment run
        (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

        # Load model
        model = attempt_load(weights, map_location=device)  # load FP32 model
        imgsz = check_img_size(imgsz, s=model.stride.max())  # check img_size

        # Multi-GPU (nn.DataParallel) is deliberately not used: it is incompatible with .half(),
        # see https://github.com/ultralytics/yolov5/issues/99

    # Half precision is only supported on CUDA
    half = device.type != 'cpu'
    if half:
        model.half()

    # Configure
    model.eval()
    is_coco = data.endswith('coco.yaml')  # is COCO dataset
    with open(data) as f:
        data = yaml.load(f, Loader=yaml.FullLoader)  # dataset dict
    check_dataset(data)  # check
    nc = 1 if single_cls else int(data['nc'])  # number of classes
    iouv = torch.linspace(0.5, 0.95, 10).to(device)  # iou vector for mAP@0.5:0.95
    niou = iouv.numel()

    # Logging
    log_imgs, wandb = min(log_imgs, 100), None  # cap logged images at 100
    try:
        import wandb  # Weights & Biases
    except ImportError:
        log_imgs = 0  # W&B not installed: disable image logging

    # Dataloader
    if not training:
        img = torch.zeros((1, 3, imgsz, imgsz), device=device)  # init img
        if device.type != 'cpu':
            model(img.half() if half else img)  # run once (warm-up)
        path = data['test'] if opt.task == 'test' else data['val']  # path to val/test images
        dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt, pad=0.5, rect=True)[0]

    seen = 0
    confusion_matrix = ConfusionMatrix(nc=nc)
    names = dict(enumerate(model.names if hasattr(model, 'names') else model.module.names))
    coco91class = coco80_to_coco91_class()
    s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
    p, r, f1, mp, mr, map50, map, t0, t1 = 0., 0., 0., 0., 0., 0., 0., 0., 0.
    loss = torch.zeros(3, device=device)
    jdict, stats, ap, ap_class, wandb_images = [], [], [], [], []
    for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
        img = img.to(device, non_blocking=True)
        img = img.half() if half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        targets = targets.to(device)
        nb, _, height, width = img.shape  # batch size, channels, height, width

        with torch.no_grad():
            # Run model
            t = time_synchronized()
            inf_out, train_out = model(img, augment=augment)  # inference and training outputs
            t0 += time_synchronized() - t

            # Compute loss
            if training:
                loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3]  # box, obj, cls

            # Run NMS
            targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device)  # to pixels
            lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else []  # for autolabelling
            t = time_synchronized()
            output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, labels=lb)
            t1 += time_synchronized() - t

        # Statistics per image
        for si, pred in enumerate(output):
            labels = targets[targets[:, 0] == si, 1:]
            nl = len(labels)
            tcls = labels[:, 0].tolist() if nl else []  # target class
            path = Path(paths[si])
            seen += 1

            if len(pred) == 0:
                if nl:
                    stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
                continue

            # Predictions: predn is pred rescaled from inference (img) space to native image space
            predn = pred.clone()
            scale_coords(img[si].shape[1:], predn[:, :4], shapes[si][0], shapes[si][1])  # native-space pred

            # Append to text file
            if save_txt:
                gn = torch.tensor(shapes[si][0])[[1, 0, 1, 0]]  # normalization gain whwh
                for *xyxy, conf, cls in predn.tolist():
                    xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                    line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
                    with open(save_dir / 'labels' / (path.stem + '.txt'), 'a') as f:
                        f.write(('%g ' * len(line)).rstrip() % line + '\n')

            # W&B logging (inference-space boxes drawn on the network input image)
            if plots and len(wandb_images) < log_imgs:
                box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
                             "class_id": int(cls),
                             "box_caption": "%s %.3f" % (names[cls], conf),
                             "scores": {"class_score": conf},
                             "domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
                boxes = {"predictions": {"box_data": box_data, "class_labels": names}}  # inference-space
                wandb_images.append(wandb.Image(img[si], boxes=boxes, caption=path.name))

            # Append to pycocotools JSON dictionary
            if save_json:
                # [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ...
                image_id = int(path.stem) if path.stem.isnumeric() else path.stem
                box = xyxy2xywh(predn[:, :4])  # xywh
                box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
                for p, b in zip(pred.tolist(), box.tolist()):
                    jdict.append({'image_id': image_id,
                                  'category_id': coco91class[int(p[5])] if is_coco else int(p[5]),
                                  'bbox': [round(x, 3) for x in b],
                                  'score': round(p[4], 5)})

            # Assign all predictions as incorrect
            correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool, device=device)
            if nl:
                detected = []  # target indices
                tcls_tensor = labels[:, 0]

                # target boxes
                tbox = xywh2xyxy(labels[:, 1:5])
                scale_coords(img[si].shape[1:], tbox, shapes[si][0], shapes[si][1])  # native-space labels
                if plots:
                    # tbox is native-space, so match it against native-space predictions (predn), not pred
                    confusion_matrix.process_batch(predn, torch.cat((labels[:, 0:1], tbox), 1))

                # Per target class
                for cls in torch.unique(tcls_tensor):
                    ti = (cls == tcls_tensor).nonzero(as_tuple=False).view(-1)  # target indices
                    pi = (cls == pred[:, 5]).nonzero(as_tuple=False).view(-1)  # prediction indices

                    # Search for detections
                    if pi.shape[0]:
                        # Prediction to target ious
                        ious, i = box_iou(predn[pi, :4], tbox[ti]).max(1)  # best ious, indices

                        # Append detections; each target may be matched by at most one prediction
                        detected_set = set()
                        for j in (ious > iouv[0]).nonzero(as_tuple=False):
                            d = ti[i[j]]  # detected target
                            if d.item() not in detected_set:
                                detected_set.add(d.item())
                                detected.append(d)
                                correct[pi[j]] = ious[j] > iouv  # iou_thres is 1xn
                                if len(detected) == nl:  # all targets already located in image
                                    break

            # Append statistics (correct, conf, pcls, tcls)
            stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))

        # Plot images
        if plots and batch_i < 3:
            f = save_dir / f'test_batch{batch_i}_labels.jpg'  # labels
            Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start()
            f = save_dir / f'test_batch{batch_i}_pred.jpg'  # predictions
            Thread(target=plot_images, args=(img, output_to_target(output), paths, f, names), daemon=True).start()

    # Compute statistics
    stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy
    if len(stats) and stats[0].any():
        p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
        p, r, ap50, ap = p[:, 0], r[:, 0], ap[:, 0], ap.mean(1)  # [P, R, AP@0.5, AP@0.5:0.95]
        mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
    # Count targets whenever any statistics exist, even if no prediction was correct,
    # so the "Targets" column is not misreported as 0
    if len(stats):
        nt = np.bincount(stats[3].astype(np.int64), minlength=nc)  # number of targets per class
    else:
        nt = np.zeros(1)

    # Print results
    pf = '%20s' + '%12.3g' * 6  # print format
    print(pf % ('all', seen, nt.sum(), mp, mr, map50, map))

    # Print results per class
    if verbose and nc > 1 and len(stats):
        for i, c in enumerate(ap_class):
            print(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))

    # Print speeds (per image, in ms); guard against an empty dataloader
    t = tuple(x / max(seen, 1) * 1E3 for x in (t0, t1, t0 + t1)) + (imgsz, imgsz, batch_size)  # tuple
    if not training:
        print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t)

    # Plots
    if plots:
        confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
        if wandb and wandb.run:
            wandb.log({"Images": wandb_images})
            wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]})

    # Save JSON
    if save_json and len(jdict):
        w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else ''  # weights
        anno_json = '../coco/annotations/instances_val2017.json'  # annotations json
        pred_json = str(save_dir / f"{w}_predictions.json")  # predictions json
        print('\nEvaluating pycocotools mAP... saving %s...' % pred_json)
        with open(pred_json, 'w') as f:
            json.dump(jdict, f)

        try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
            from pycocotools.coco import COCO
            from pycocotools.cocoeval import COCOeval

            anno = COCO(anno_json)  # init annotations api
            pred = anno.loadRes(pred_json)  # init predictions api
            coco_eval = COCOeval(anno, pred, 'bbox')
            if is_coco:
                coco_eval.params.imgIds = [int(Path(x).stem) for x in dataloader.dataset.img_files]  # image IDs to evaluate
            coco_eval.evaluate()
            coco_eval.accumulate()
            coco_eval.summarize()
            map, map50 = coco_eval.stats[:2]  # update results (mAP@0.5:0.95, mAP@0.5)
        except Exception as e:  # best-effort, e.g. pycocotools not installed or annotations file missing
            print(f'pycocotools unable to run: {e}')

    # Return results
    if not training:
        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
        print(f"Results saved to {save_dir}{s}")
    model.float()  # for training
    maps = np.zeros(nc) + map  # classes without stats default to the overall mAP
    for i, c in enumerate(ap_class):
        maps[c] = ap[i]
    return (mp, mr, map50, map, *(loss.cpu() / max(len(dataloader), 1)).tolist()), maps, t
G
Glenn Jocher 已提交
279 280 281


if __name__ == '__main__':
    # Command-line entry point. `opt` is intentionally module-level: test() reads
    # opt.device, opt.project, opt.name, opt.exist_ok and opt.task when called directly.
    parser = argparse.ArgumentParser(prog='test.py')
    parser.add_argument('--weights', nargs='+', type=str, default='yolov3.pt', help='model.pt path(s)')
    parser.add_argument('--data', type=str, default='data/coco128.yaml', help='dataset.yaml path')
    parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch')
    parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.001, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.6, help='IOU threshold for NMS')
    # `choices` rejects unknown tasks instead of silently doing nothing
    parser.add_argument('--task', default='val', choices=['val', 'test', 'study'], help="'val', 'test', 'study'")
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--verbose', action='store_true', help='report mAP by class')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--save-hybrid', action='store_true', help='save label+prediction hybrid results to *.txt')
    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
    parser.add_argument('--save-json', action='store_true', help='save a cocoapi-compatible JSON results file')
    parser.add_argument('--project', default='runs/test', help='save to project/name')
    parser.add_argument('--name', default='exp', help='save to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    opt = parser.parse_args()
    opt.save_json |= opt.data.endswith('coco.yaml')  # always run pycocotools evaluation on COCO
    opt.data = check_file(opt.data)  # check file
    print(opt)

    if opt.task in ['val', 'test']:  # run normally
        test(opt.data,
             opt.weights,
             opt.batch_size,
             opt.img_size,
             opt.conf_thres,
             opt.iou_thres,
             opt.save_json,
             opt.single_cls,
             opt.augment,
             opt.verbose,
             save_txt=opt.save_txt or opt.save_hybrid,  # hybrid labels are written through the save_txt path
             save_hybrid=opt.save_hybrid,
             save_conf=opt.save_conf,
             )

    elif opt.task == 'study':  # run over a range of settings and save/plot
        for weights in ['yolov3.pt', 'yolov3-spp.pt', 'yolov3-tiny.pt']:
            f = 'study_%s_%s.txt' % (Path(opt.data).stem, Path(weights).stem)  # filename to save to
            x = list(range(320, 800, 64))  # x axis (image sizes)
            y = []  # y axis
            for i in x:  # img-size
                print('\nRunning %s point %s...' % (f, i))
                r, _, t = test(opt.data, weights, opt.batch_size, i, opt.conf_thres, opt.iou_thres, opt.save_json,
                               plots=False)
                y.append(r + t)  # results and times
            np.savetxt(f, y, fmt='%10.4g')  # save
        os.system('zip -r study.zip study_*.txt')
        # NOTE(review): only the last study file `f` is passed here; confirm this matches plot_study_txt's signature
        plot_study_txt(f, x)  # plot