# train.py: training script for timm image classification models.
import argparse
import logging
import os
import sys
import time
from collections import OrderedDict
from datetime import datetime

import numpy as np
import yaml

try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as DDP
    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    from torch.nn.parallel import DistributedDataParallel as DDP
    has_apex = False

from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
from timm.models import create_model, resume_checkpoint
# The star import supplies the training helpers used below that are not imported
# explicitly (AverageMeter, ModelEma, CheckpointSaver, get_outdir, update_summary,
# accuracy, reduce_tensor, setup_default_logging).
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler

import torch
import torch.nn as nn
import torchvision.utils

torch.backends.cudnn.benchmark = True


# The first arg parser parses out only the --config argument. That argument names a
# YAML file whose key/value pairs override the defaults of the main parser below.
config_parser = argparse.ArgumentParser(description='Training Config', add_help=False)
config_parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
                           help='YAML config file specifying default arguments')


# Main parser: every training option. Its defaults may be replaced by the YAML config.
parser = argparse.ArgumentParser(description='Training')
# Dataset / Model parameters
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
                    help='Name of model to train (default: "resnet101")')
parser.add_argument('--pretrained', action='store_true', default=False,
                    help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                    help='Initialize model from this checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument('--no-resume-opt', action='store_true', default=False,
                    help='prevent resume of optimizer state when resuming model')
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
                    help='number of label classes (default: 1000)')
parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
                    help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")')
parser.add_argument('--img-size', type=int, default=None, metavar='N',
                    help='Image patch size (default: None => model default)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
                    help='input batch size for training (default: 32)')
parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
                    help='Dropout rate (default: 0.)')
# Optimizer parameters
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "sgd")')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                    help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.0001,
                    help='weight decay (default: 0.0001)')
# Learning rate schedule parameters
parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
                    help='LR scheduler (default: "step")')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
                    help='warmup learning rate (default: 0.0001)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                    help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-5)')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train (default: 200)')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=int, default=30, metavar='N',
                    help='epoch interval to decay LR (default: 30)')
parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
                    help='epochs to warmup LR, if scheduler supports (default: 3)')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                    help='epochs to cooldown LR at min_lr, after cyclic schedule ends (default: 10)')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                    help='LR decay rate (default: 0.1)')
# Augmentation parameters
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                    help='Color jitter factor (default: 0.4)')
parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                    help='Use AutoAugment policy. "v0" or "original". (default: None)')
parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
                    help='Random erase prob (default: 0.)')
parser.add_argument('--remode', type=str, default='const',
                    help='Random erase mode (default: "const")')
parser.add_argument('--recount', type=int, default=1,
                    help='Random erase count (default: 1)')
parser.add_argument('--mixup', type=float, default=0.0,
                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                    help='turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
                    help='label smoothing (default: 0.1)')
# Batch norm parameters (only works with gen_efficientnet based models currently)
parser.add_argument('--bn-tf', action='store_true', default=False,
                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
parser.add_argument('--bn-momentum', type=float, default=None,
                    help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None,
                    help='BatchNorm epsilon override (if not None)')
# Model Exponential Moving Average
parser.add_argument('--model-ema', action='store_true', default=False,
                    help='Enable tracking moving average of model weights')
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
parser.add_argument('--model-ema-decay', type=float, default=0.9998,
                    help='decay factor for model weights moving average (default: 0.9998)')
# Misc
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                    help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                    help='how many data loading worker processes to use (default: 4)')
parser.add_argument('--num-gpu', type=int, default=1,
                    help='Number of GPUs to use (default: 1)')
parser.add_argument('--save-images', action='store_true', default=False,
                    help='save images of input batches every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
                    help='use NVIDIA amp for mixed precision training')
parser.add_argument('--sync-bn', action='store_true',
                    help='enable synchronized BatchNorm (Apex if installed, else torch SyncBatchNorm)')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
                    help='disable fast prefetcher')
parser.add_argument('--output', default='', type=str, metavar='PATH',
                    help='path to output folder (default: none => "./output")')
parser.add_argument('--eval-metric', default='prec1', type=str, metavar='EVAL_METRIC',
                    help='Best metric (default: "prec1")')
parser.add_argument('--tta', type=int, default=0, metavar='N',
                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument("--local_rank", default=0, type=int,
                    help='local process rank, used as the CUDA device index in distributed mode (default: 0)')


def _parse_args():
    """Parse command-line arguments, optionally layering defaults from a YAML file.

    ``config_parser`` first extracts only ``-c/--config``. When a config file is
    given, its top-level key/value pairs are installed as new defaults on the main
    ``parser``. Flags given explicitly on the command line therefore still win.
    The remaining argv is then parsed by ``parser``.

    Returns:
        tuple: ``(args, args_text)``.
            ``args`` is the parsed ``argparse.Namespace``.
            ``args_text`` is a YAML dump of ``args``, cached so it can be saved to
            the output folder later.
    """
    # Do we have a config file to parse?
    args_config, remaining = config_parser.parse_known_args()
    if args_config.config:
        with open(args_config.config, 'r') as f:
            cfg = yaml.safe_load(f)
        # An empty YAML file loads as None; treat it as "no overrides" instead of
        # crashing inside set_defaults(**None).
        if cfg:
            parser.set_defaults(**cfg)

    # The main arg parser parses the rest of the args, the usual
    # defaults will have been overridden if config file specified.
    args = parser.parse_args(remaining)

    # Cache the args as a text string to save them in the output dir later
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
    return args, args_text


def main():
    """Run a full training job configured from the command line / YAML config.

    Steps, in order:
      1. Parse arguments (``_parse_args``) and derive ``args.prefetcher``.
      2. Enable distributed mode when the ``WORLD_SIZE`` env var is > 1.
         Distributed mode initialises an NCCL process group with one GPU per process.
      3. Create the model and move it to CUDA. When ``--num-gpu > 1`` it is wrapped
         in ``nn.DataParallel``. Then build the optimizer and, when APEX is
         available and ``--amp`` is set, initialise AMP.
      4. Optionally restore model / optimizer / AMP state from ``--resume``.
      5. Optionally create the EMA copy, then apply SyncBN and the DDP wrapper
         (distributed mode only).
      6. Build the LR scheduler, the train/eval loaders and the loss functions.
      7. On local rank 0, perform the output-directory setup:
         - create the output directory and a ``CheckpointSaver``;
         - write the resolved arguments to ``args.yaml``.
      8. Loop over epochs (Ctrl-C stops early). Each epoch trains, validates,
         steps the scheduler, appends to ``summary.csv`` and saves a checkpoint
         (the last two on local rank 0 only).

    Exits the process with status 1 when the train or validation folder is missing.
    """
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            logging.warning('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' % args.num_gpu)

    # Offset the seed by the global rank so each process gets its own random stream.
    torch.manual_seed(args.seed + args.rank)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        checkpoint_path=args.initial_checkpoint)

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum(m.numel() for m in model.parameters())))

    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.')
            args.amp = False
        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    optimizer = create_optimizer(args, model)

    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))

    # optionally resume from a checkpoint
    resume_state = {}
    resume_epoch = None
    if args.resume:
        resume_state, resume_epoch = resume_checkpoint(model, args.resume)
    if resume_state and not args.no_resume_opt:
        if 'optimizer' in resume_state:
            if args.local_rank == 0:
                logging.info('Restoring Optimizer state from checkpoint')
            optimizer.load_state_dict(resume_state['optimizer'])
        # amp.load_state_dict may be missing (presumably in older APEX builds), so probe first
        if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
            if args.local_rank == 0:
                logging.info('Restoring NVIDIA AMP state from checkpoint')
            amp.load_state_dict(resume_state['amp'])
    resume_state = None  # drop the reference to the loaded checkpoint contents

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume=args.resume)

    if args.distributed:
        if args.sync_bn:
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                if args.local_rank == 0:
                    logging.info('Converted model to use Synchronized BatchNorm.')
            except Exception:
                # Best effort: keep training with regular BN, but keep the traceback
                # in the log instead of silently discarding the cause.
                logging.exception('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
            model = DDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    train_dir = os.path.join(args.data, 'train')
    if not os.path.isdir(train_dir):
        logging.error('Training folder does not exist at: {}'.format(train_dir))
        sys.exit(1)
    dataset_train = Dataset(train_dir)

    collate_fn = None
    if args.prefetcher and args.mixup > 0:
        # With the prefetcher, mixup happens in the collate step; train_epoch
        # applies mixup itself only when the prefetcher is disabled.
        collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)

    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        rand_erase_prob=args.reprob,
        rand_erase_mode=args.remode,
        rand_erase_count=args.recount,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation='random',  # FIXME cleanly resolve this? data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
    )

    # Accept either <data>/val or <data>/validation as the evaluation split.
    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            logging.error('Validation folder does not exist at: {}'.format(eval_dir))
            sys.exit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
    )

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''  # stays '' on every process except local rank 0
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = eval_metric == 'loss'  # only loss is "lower is better"
        saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                # presumably a DistributedSampler: reseed its shuffling for this epoch
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                use_amp=use_amp, model_ema=model_ema)

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args)

            if model_ema is not None and not args.model_ema_force_cpu:
                # EMA metrics replace the regular ones for scheduling, summary and checkpointing
                ema_eval_metrics = validate(
                    model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            if output_dir:
                # Only the process owning the output dir writes the summary. Previously
                # every other rank wrote 'summary.csv' into the CWD with a header each epoch.
                update_summary(
                    epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                    write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model, optimizer, args,
                    epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=use_amp)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
414 415 416 417


def train_epoch(
        epoch, model, loader, optimizer, loss_fn, args,
        lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None):
    """Train ``model`` for one epoch over ``loader``.

    Args:
        epoch: Epoch index. Used for the mixup cut-off, the LR update counter and logging.
        model: Model to train; switched to train mode here.
        loader: Training loader. With ``args.prefetcher`` and ``args.mixup > 0`` it
            must expose a ``mixup_enabled`` attribute, which is read and may be cleared.
        optimizer: Optimizer, stepped once per batch.
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar loss tensor.
        args: Parsed arguments. Fields read:
            prefetcher, mixup, mixup_off_epoch, num_classes, smoothing, distributed,
            world_size, local_rank, log_interval, recovery_interval, save_images.
        lr_scheduler: Optional scheduler; ``step_update`` is called after every batch.
        saver: Optional checkpoint saver used for periodic recovery checkpoints.
        output_dir: Folder for ``--save-images`` debug dumps; '' disables them.
        use_amp: Backpropagate through APEX ``amp.scale_loss``.
        model_ema: Optional EMA tracker, updated after every optimizer step.

    Returns:
        OrderedDict with one entry, ``'loss'``: the running average training loss.
        In distributed mode only the logged batches are averaged, after an all-rank
        ``reduce_tensor``.
    """
    if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
        if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
            # With the prefetcher, mixup is done inside the loader; switch it off there.
            loader.mixup_enabled = False

    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()

    model.train()

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (inputs, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher:
            inputs, target = inputs.cuda(), target.cuda()
            if args.mixup > 0.:
                lam = 1.
                if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
                    lam = np.random.beta(args.mixup, args.mixup)
                # Blend every sample with its mirror in the batch (inputs.flip(0)).
                # Uses the `alpha=` form; the positional (scalar, tensor) add_ overload is deprecated.
                inputs = inputs.mul(lam).add_(inputs.flip(0), alpha=1 - lam)
                target = mixup_target(target, args.num_classes, lam, args.smoothing)

        output = model(inputs)

        loss = loss_fn(output, target)
        if not args.distributed:
            losses_m.update(loss.item(), inputs.size(0))

        optimizer.zero_grad()
        if use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        # Wait for queued CUDA work so the batch timing below reflects real compute time.
        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)
        num_updates += 1

        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                # Called on every rank, before the rank-0 check below.
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), inputs.size(0))

            if args.local_rank == 0:
                # A single-batch loader has last_idx == 0; report it as 100% done
                # instead of dividing by zero.
                pct_done = 100. * batch_idx / last_idx if last_idx > 0 else 100.
                logging.info(
                    'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '
                    'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '
                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '
                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'LR: {lr:.3e}  '
                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                        epoch,
                        batch_idx, len(loader),
                        pct_done,
                        loss=losses_m,
                        batch_time=batch_time_m,
                        rate=inputs.size(0) * args.world_size / batch_time_m.val,
                        rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,
                        lr=lr,
                        data_time=data_time_m))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        inputs,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(
                model, optimizer, args, epoch, model_ema=model_ema, use_amp=use_amp, batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    return OrderedDict([('loss', losses_m.avg)])


def validate(model, loader, loss_fn, args, log_suffix=''):
    """Evaluate ``model`` on ``loader`` without gradients.

    Args:
        model: Model to evaluate; switched to eval mode here. If it returns a
            tuple/list, only the first element is scored.
        loader: Evaluation loader yielding ``(inputs, target)`` batches.
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar loss tensor.
        args: Parsed arguments. Fields read:
            prefetcher, tta, distributed, world_size, local_rank, log_interval.
        log_suffix: Text appended to the "Test" log prefix (e.g. ' (EMA)').

    Returns:
        OrderedDict with keys ``'loss'``, ``'prec1'`` and ``'prec5'``, each holding the
        running average over all batches. In distributed mode the per-batch values
        are first combined across processes with ``reduce_tensor``.
    """
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    prec1_m = AverageMeter()
    prec5_m = AverageMeter()

    model.eval()

    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (inputs, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            if not args.prefetcher:
                inputs = inputs.cuda()
                target = target.cuda()

            output = model(inputs)
            if isinstance(output, (tuple, list)):
                output = output[0]

            # augmentation reduction: average each group of `tta` consecutive predictions
            # NOTE(review): assumes the eval batch size is a multiple of --tta. unfold drops
            # a trailing partial group, while the target slice below keeps it.
            reduce_factor = args.tta
            if reduce_factor > 1:
                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]

            loss = loss_fn(output, target)
            # topk fails when k exceeds the class count, so cap "top-5" at the number
            # of classes. Results are unchanged for models with >= 5 classes.
            maxk = min(5, output.size(1))
            prec1, prec5 = accuracy(output, target, topk=(1, maxk))

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                prec1 = reduce_tensor(prec1, args.world_size)
                prec5 = reduce_tensor(prec5, args.world_size)
            else:
                reduced_loss = loss.data

            torch.cuda.synchronize()

            losses_m.update(reduced_loss.item(), inputs.size(0))
            prec1_m.update(prec1.item(), output.size(0))
            prec5_m.update(prec5.item(), output.size(0))

            batch_time_m.update(time.time() - end)
            end = time.time()
            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                log_name = 'Test' + log_suffix
                logging.info(
                    '{0}: [{1:>4d}/{2}]  '
                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '
                    'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                        log_name, batch_idx, last_idx,
                        batch_time=batch_time_m, loss=losses_m,
                        top1=prec1_m, top5=prec5_m))

    metrics = OrderedDict([('loss', losses_m.avg), ('prec1', prec1_m.avg), ('prec5', prec5_m.avg)])

    return metrics


if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    main()