From edb425ea4886896f24d0651f793fc5a889589e50 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 3 Jul 2019 22:28:07 -0700
Subject: [PATCH] Add crop_pct arg to validate, extra fields to csv output, 'all' filters pretrained

---
 validate.py | 23 +++++++++++++++--------
 1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/validate.py b/validate.py
index 280bc26..a859a09 100644
--- a/validate.py
+++ b/validate.py
@@ -30,6 +30,8 @@ parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N',
                     help='mini-batch size (default: 256)')
 parser.add_argument('--img-size', default=None, type=int, metavar='N',
                     help='Input image dimension, uses model default if empty')
+parser.add_argument('--crop-pct', default=None, type=float,
+                    metavar='N', help='Input image center crop pct')
 parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                     help='Override mean pixel value of dataset')
 parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
@@ -81,6 +83,7 @@ def validate(args):
 
     criterion = nn.CrossEntropyLoss().cuda()
 
+    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
     loader = create_loader(
         Dataset(args.data, load_bytes=args.tf_preprocessing),
         input_size=data_config['input_size'],
@@ -90,7 +93,7 @@ def validate(args):
         mean=data_config['mean'],
         std=data_config['std'],
         num_workers=args.workers,
-        crop_pct=1.0 if test_time_pool else data_config['crop_pct'],
+        crop_pct=crop_pct,
         tf_preprocessing=args.tf_preprocessing)
 
     batch_time = AverageMeter()
@@ -124,16 +127,19 @@ def validate(args):
                     'Test: [{0:>4d}/{1}] '
                     'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                     'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
-                    'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
-                    'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
+                    'Prec@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
+                    'Prec@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                         i, len(loader), batch_time=batch_time,
                         rate_avg=input.size(0) / batch_time.avg,
                         loss=losses, top1=top1, top5=top5))
 
     results = OrderedDict(
-        top1=round(top1.avg, 3), top1_err=round(100 - top1.avg, 3),
-        top5=round(top5.avg, 3), top5_err=round(100 - top5.avg, 3),
-        param_count=round(param_count / 1e6, 2))
+        top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
+        top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4),
+        param_count=round(param_count / 1e6, 2),
+        img_size=data_config['input_size'][-1],
+        crop_pct=crop_pct,
+        interpolation=data_config['interpolation'])
 
     logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
         results['top1'], results['top1_err'], results['top5'], results['top5_err']))
@@ -155,7 +161,7 @@ def main():
     if args.model == 'all':
         # validate all models in a list of names with pretrained checkpoints
         args.pretrained = True
-        model_names = list_models()
+        model_names = list_models(pretrained=True)
         model_cfgs = [(n, '') for n in model_names]
     elif not is_model(args.model):
         # model name doesn't exist, try as wildcard filter
@@ -170,7 +176,8 @@ def main():
             args.model = m
             args.checkpoint = c
             result = OrderedDict(model=args.model)
-            result.update(validate(args))
+            r = validate(args)
+            result.update(r)
             if args.checkpoint:
                 result['checkpoint'] = args.checkpoint
             dw = csv.DictWriter(cf, fieldnames=result.keys())
-- 
GitLab
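
For context on the --crop-pct value introduced above: in timm's evaluation pipeline the crop percentage is the ratio of the final network input size to the intermediate resize applied before center cropping, and the hunk above forces it to 1.0 whenever test-time pooling is active. Below is a minimal sketch of that relationship using torchvision-style transforms; build_eval_transform and scale_size are illustrative names, not part of validate.py or create_loader.

import math

from torchvision import transforms


def build_eval_transform(img_size=224, crop_pct=0.875,
                         mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)):
    # Resize the shorter edge to img_size / crop_pct, then center crop back
    # down to img_size; crop_pct=1.0 resizes straight to the target size so
    # no border is discarded.
    scale_size = int(math.floor(img_size / crop_pct))
    return transforms.Compose([
        transforms.Resize(scale_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])


# The common 0.875 default maps a 224px input to a 256px resize before the
# crop; passing --crop-pct 1.0 would evaluate on the full resized image.
print(build_eval_transform(224, 0.875))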
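
The 'all' branch in main() now restricts bulk validation to models that actually ship pretrained weights. For reference, a short sketch of how timm's list_models behaves with the pretrained flag and with a wildcard filter; the wildcard call mirrors the 'try as wildcard filter' branch, whose exact call sits outside this hunk, so treat that part as an assumption.

from timm import list_models

# Only model names with pretrained checkpoints, matching the 'all' path above.
pretrained_names = list_models(pretrained=True)

# Wildcard filtering by name pattern, e.g. every Gluon-ported ResNet variant.
gluon_resnets = list_models('gluon_resnet*')

print(len(pretrained_names), len(gluon_resnets))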