# train.py -- image classification training script built on timm.
import argparse
import logging
import os
import sys
import time
from collections import OrderedDict
from datetime import datetime

try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as DDP
    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    has_apex = False

from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
from timm.models import create_model, resume_checkpoint
# Wildcard import: supplies helpers used below that have no other import
# (setup_default_logging, AverageMeter, ModelEma, CheckpointSaver, get_outdir,
# update_summary, accuracy, reduce_tensor). Kept ahead of the explicit
# numpy/torch imports so those names always take precedence.
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler

import numpy as np
import torch
import torch.nn as nn
import torchvision.utils

# Enable the cuDNN autotuner: it benchmarks convolution algorithms and picks the
# fastest one, which pays off when input sizes stay fixed across iterations.
torch.backends.cudnn.benchmark = True

# Command line interface. ``data`` is the dataset root; main() expects
# ``train`` and ``validation`` sub-folders beneath it.
parser = argparse.ArgumentParser(description='Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
                    help='Name of model to train (default: "resnet101")')
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
                    help='number of label classes (default: 1000)')
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "sgd")')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                    help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
                    help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")')
parser.add_argument('--tta', type=int, default=0, metavar='N',
                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument('--pretrained', action='store_true', default=False,
                    help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--img-size', type=int, default=None, metavar='N',
                    help='Image patch size (default: None => model default)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
                    help='input batch size for training (default: 32)')
parser.add_argument('-s', '--initial-batch-size', type=int, default=0, metavar='N',
                    help='initial input batch size for training (default: 0)')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train (default: 200)')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=int, default=30, metavar='N',
                    help='epoch interval to decay LR (default: 30)')
parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
                    help='epochs to warmup LR, if scheduler supports (default: 3)')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                    help='LR decay rate (default: 0.1)')
parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
                    help='LR scheduler (default: "step")')
parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
                    help='Dropout rate (default: 0.)')
parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
                    help='Random erase prob (default: 0.)')
parser.add_argument('--remode', type=str, default='const',
                    help='Random erase mode (default: "const")')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
                    help='warmup learning rate (default: 0.0001)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.0001,
                    help='weight decay (default: 0.0001)')
parser.add_argument('--mixup', type=float, default=0.0,
                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                    help='turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
                    help='label smoothing (default: 0.1)')
parser.add_argument('--bn-tf', action='store_true', default=False,
                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
parser.add_argument('--bn-momentum', type=float, default=None,
                    help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None,
                    help='BatchNorm epsilon override (if not None)')
parser.add_argument('--model-ema', action='store_true', default=False,
                    help='Enable tracking moving average of model weights')
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
parser.add_argument('--model-ema-decay', type=float, default=0.9998,
                    help='decay factor for model weights moving average (default: 0.9998)')
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                    help='how many batches to wait before writing recovery checkpoint')
# Passed to the data loaders as ``num_workers``; not the number of training processes.
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                    help='how many data loading worker processes to use (default: 4)')
parser.add_argument('--num-gpu', type=int, default=1,
                    help='Number of GPUS to use')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                    help='path to init checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--save-images', action='store_true', default=False,
                    help='save images of input batches every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
                    help='use NVIDIA amp for mixed precision training')
parser.add_argument('--sync-bn', action='store_true',
                    help='enabling apex sync BN.')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
                    help='disable fast prefetcher')
parser.add_argument('--output', default='', type=str, metavar='PATH',
                    help='path to output folder (default: none, ./output)')
parser.add_argument('--eval-metric', default='prec1', type=str, metavar='EVAL_METRIC',
                    help='Best metric (default: "prec1")')
# Set per process by the distributed launcher.
parser.add_argument("--local_rank", default=0, type=int)


def main():
    """Parse the CLI, build model/data/optimizer, then run the train/validate loop.

    Supports three modes: a single process on one GPU, a single process using
    ``nn.DataParallel`` over ``--num-gpu`` GPUs, and one-process-per-GPU
    distributed training when ``WORLD_SIZE`` > 1 is present in the environment
    (this mode uses apex's DistributedDataParallel).

    Exits with status 1 when distributed training is requested without apex, or
    when the ``train`` / ``validation`` folders are missing under the dataset root.
    """
    setup_default_logging()
    args = parser.parse_args()
    args.prefetcher = not args.no_prefetcher

    # ---- process / device setup ----
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            logging.warning('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        if not has_apex:
            # DDP below is apex's DistributedDataParallel; it is undefined when the
            # apex import failed, so fail fast instead of hitting a NameError later.
            logging.error('Distributed training requires NVIDIA apex, which could not be imported.')
            sys.exit(1)
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' % args.num_gpu)

    # Offset the seed by rank so each process draws different random numbers.
    torch.manual_seed(args.seed + args.rank)

    # ---- model ----
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        checkpoint_path=args.initial_checkpoint)

    logging.info('Model %s created, param count: %d' %
                 (args.model, sum(m.numel() for m in model.parameters())))

    data_config = resolve_data_config(model, args, verbose=args.local_rank == 0)

    # --start-epoch is honoured even without --resume; when resuming,
    # resume_checkpoint decides the epoch (it is also given args.start_epoch).
    start_epoch = 0 if args.start_epoch is None else args.start_epoch
    optimizer_state = None
    if args.resume:
        optimizer_state, start_epoch = resume_checkpoint(model, args.resume, args.start_epoch)

    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.')
            args.amp = False
        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        if args.distributed and args.sync_bn and has_apex:
            model = convert_syncbn_model(model)
        model.cuda()

    # ---- optimizer / mixed precision ----
    optimizer = create_optimizer(args, model)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)

    use_amp = has_apex and args.amp
    if use_amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        logging.info('AMP enabled')
    else:
        if args.amp:
            logging.warning('AMP requested but NVIDIA apex could not be imported.')
        logging.info('AMP disabled')

    model_ema = None
    if args.model_ema:
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume=args.resume)

    if args.distributed:
        model = DDP(model, delay_allreduce=True)
        if model_ema is not None and not args.model_ema_force_cpu:
            # must also distribute EMA model to allow validation
            model_ema.ema = DDP(model_ema.ema, delay_allreduce=True)
            model_ema.ema_has_module = True

    # ---- LR schedule ----
    # step(e) sets the learning rate used while training epoch e.
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    # ---- data ----
    train_dir = os.path.join(args.data, 'train')
    if not os.path.isdir(train_dir):
        logging.error('Training folder does not exist at: {}'.format(train_dir))
        sys.exit(1)
    dataset_train = Dataset(train_dir)

    collate_fn = None
    if args.prefetcher and args.mixup > 0:
        collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)

    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        rand_erase_prob=args.reprob,
        rand_erase_mode=args.remode,
        interpolation='random',  # FIXME cleanly resolve this? data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
    )

    eval_dir = os.path.join(args.data, 'validation')
    if not os.path.isdir(eval_dir):
        logging.error('Validation folder does not exist at: {}'.format(eval_dir))
        sys.exit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
    )

    # ---- losses ----
    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    # ---- output / checkpointing (local rank 0 only) ----
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = eval_metric == 'loss'
        saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)

    # ---- train / validate loop ----
    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                use_amp=use_amp, model_ema=model_ema)

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args)

            if model_ema is not None and not args.model_ema_force_cpu:
                ema_eval_metrics = validate(
                    model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # `epoch` has just finished: set the LR for the next one. Stepping
                # with `epoch` would re-apply the LR already used and lag every
                # warmup/decay boundary by one epoch (and disagree with the
                # step(start_epoch) call made when resuming).
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            # Only the process that owns an output directory writes the summary;
            # other ranks have output_dir == '' and would otherwise write
            # summary.csv into the current working directory.
            if output_dir:
                update_summary(
                    epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                    write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model, optimizer, args,
                    epoch=epoch + 1,
                    model_ema=model_ema,
                    metric=save_metric)

    except KeyboardInterrupt:
        # Ctrl-C ends training early but still reports the best result so far.
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))


def train_epoch(
        epoch, model, loader, optimizer, loss_fn, args,
        lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None):
    """Train ``model`` for one epoch over ``loader``.

    Args:
        epoch: zero-based index of the epoch being trained.
        model: network to train (possibly wrapped in DataParallel / DDP).
        loader: training data loader. Without ``args.prefetcher`` its batches are
            moved to the GPU here; with it, the loader is presumably already
            yielding GPU tensors and applying mixup itself (``mixup_enabled``).
        optimizer: optimizer updating ``model``'s parameters.
        loss_fn: criterion called as ``loss_fn(output, target)``.
        args: parsed command line namespace.
        lr_scheduler: optional scheduler; ``step_update`` is called after every batch.
        saver: optional checkpoint saver used for periodic recovery checkpoints.
        output_dir: directory for debug images written when ``args.save_images`` is set.
        use_amp: scale the loss through apex ``amp`` before the backward pass.
        model_ema: optional EMA wrapper updated after every optimizer step.

    Returns:
        OrderedDict with the epoch's average training loss under ``'loss'``.
        In distributed mode the loss meter is only updated at log intervals.
    """
    # With the prefetcher, mixup happens in the loader; switch it off once the
    # configured mixup-off epoch is reached.
    if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
        if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
            loader.mixup_enabled = False

    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()

    model.train()

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher:
            input = input.cuda()
            target = target.cuda()
            if args.mixup > 0.:
                lam = 1.
                if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
                    lam = np.random.beta(args.mixup, args.mixup)
                # Blend every sample with its mirror in the batch. This must be
                # computed out-of-place: scaling `input` in place before taking
                # flip(0) would also scale the partner image by lam, giving
                # lam * x + lam * (1 - lam) * x_flipped.
                input = input * lam + input.flip(0) * (1. - lam)
                target = mixup_target(target, args.num_classes, lam, args.smoothing)

        output = model(input)

        loss = loss_fn(output, target)
        if not args.distributed:
            losses_m.update(loss.item(), input.size(0))

        optimizer.zero_grad()
        if use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        # Wait for GPU work so the batch timings below are meaningful.
        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)
        num_updates += 1

        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                # Only reduce across processes when logging, to avoid a sync every step.
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), input.size(0))

            if args.local_rank == 0:
                logging.info(
                    'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '
                    'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '
                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '
                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'LR: {lr:.3e}  '
                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                        epoch,
                        batch_idx, len(loader),
                        # max() guards single-batch loaders, where last_idx == 0
                        100. * batch_idx / max(last_idx, 1),
                        loss=losses_m,
                        batch_time=batch_time_m,
                        rate=input.size(0) * args.world_size / batch_time_m.val,
                        rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
                        lr=lr,
                        data_time=data_time_m))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        input,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            # After the last batch the epoch is complete, so record the next epoch.
            save_epoch = epoch + 1 if last_batch else epoch
            saver.save_recovery(
                model, optimizer, args, save_epoch, model_ema=model_ema, batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()

    return OrderedDict([('loss', losses_m.avg)])


def validate(model, loader, loss_fn, args, log_suffix=''):
    """Evaluate ``model`` on ``loader``; return average loss and top-1/top-5 precision.

    Args:
        model: network to evaluate; it is switched to eval mode.
        loader: evaluation data loader.
        loss_fn: criterion called as ``loss_fn(output, target)``.
        args: parsed command line namespace (reads ``prefetcher``, ``tta``,
            ``distributed``, ``world_size``, ``local_rank``, ``log_interval``).
        log_suffix: appended to the ``'Test'`` log prefix, e.g. ``' (EMA)'``.

    Returns:
        OrderedDict with keys ``'loss'``, ``'prec1'`` and ``'prec5'``.
    """
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    prec1_m = AverageMeter()
    prec5_m = AverageMeter()

    model.eval()

    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (input, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            if not args.prefetcher:
                input = input.cuda()
                target = target.cuda()

            output = model(input)
            if isinstance(output, (tuple, list)):
                output = output[0]

            # augmentation reduction: average the outputs of each run of
            # ``args.tta`` consecutive samples and keep one target per run
            # (assumes output is (batch, classes) -- TODO confirm for all models)
            reduce_factor = args.tta
            if reduce_factor > 1:
                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]

            loss = loss_fn(output, target)
            prec1, prec5 = accuracy(output, target, topk=(1, 5))

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                prec1 = reduce_tensor(prec1, args.world_size)
                prec5 = reduce_tensor(prec5, args.world_size)
            else:
                reduced_loss = loss.data

            torch.cuda.synchronize()

            # Weight every meter by the post-reduction sample count; with TTA the
            # input batch is `tta` times larger than the evaluated batch.
            num_samples = output.size(0)
            losses_m.update(reduced_loss.item(), num_samples)
            prec1_m.update(prec1.item(), num_samples)
            prec5_m.update(prec5.item(), num_samples)

            batch_time_m.update(time.time() - end)
            end = time.time()
            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                log_name = 'Test' + log_suffix
                logging.info(
                    '{0}: [{1:>4d}/{2}]  '
                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '
                    'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                        log_name, batch_idx, last_idx,
                        batch_time=batch_time_m, loss=losses_m,
                        top1=prec1_m, top5=prec5_m))

    metrics = OrderedDict([('loss', losses_m.avg), ('prec1', prec1_m.avg), ('prec5', prec5_m.avg)])

    return metrics


if __name__ == '__main__':
    # Script entry point: importing this module does not start training.
    main()