import argparse
import time
from datetime import datetime

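# Optional NVIDIA Apex: provides AMP mixed precision and a fused DistributedDataParallel when installed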
try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as DDP
    has_apex = True
except ImportError:
    has_apex = False

import os
from collections import OrderedDict

import numpy as np

from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
from models import create_model, resume_checkpoint
from utils import *
from loss import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
from optim import create_optimizer
from scheduler import create_scheduler

import torch
import torch.nn as nn
import torch.distributed as dist
import torchvision.utils

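# let cuDNN benchmark and cache the fastest conv algorithms (input sizes are fixed here)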
torch.backends.cudnn.benchmark = True

parser = argparse.ArgumentParser(description='Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
                    help='Name of model to train (default: "resnet101")')
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
                    help='number of label classes (default: 1000)')
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "sgd")')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                    help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
                    help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")')
parser.add_argument('--tta', type=int, default=0, metavar='N',
                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument('--pretrained', action='store_true', default=False,
                    help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--img-size', type=int, default=224, metavar='N',
                    help='Image patch size (default: 224)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
                    help='input batch size for training (default: 32)')
parser.add_argument('-s', '--initial-batch-size', type=int, default=0, metavar='N',
                    help='initial input batch size for training (default: 0)')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train (default: 200)')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=int, default=30, metavar='N',
                    help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
                    help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                    help='LR decay rate (default: 0.1)')
parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
                    help='LR scheduler (default: "step")')
parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
                    help='Dropout rate (default: 0.)')
parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
                    help='Random erase prob (default: 0.)')
parser.add_argument('--remode', type=str, default='const',
                    help='Random erase mode (default: "const")')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
                    help='warmup learning rate (default: 0.0001)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.0001,
                    help='weight decay (default: 0.0001)')
parser.add_argument('--mixup', type=float, default=0.0,
                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                    help='turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
                    help='label smoothing (default: 0.1)')
parser.add_argument('--bn-tf', action='store_true', default=False,
                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
parser.add_argument('--bn-momentum', type=float, default=None,
                    help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None,
                    help='BatchNorm epsilon override (if not None)')
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=1000, metavar='N',
                    help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                    help='how many data loading workers to use (default: 4)')
parser.add_argument('--num-gpu', type=int, default=1,
                    help='Number of GPUS to use')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                    help='path to init checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--save-images', action='store_true', default=False,
                    help='save images of input batches every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
                    help='use NVIDIA amp for mixed precision training')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
                    help='disable fast prefetcher')
parser.add_argument('--output', default='', type=str, metavar='PATH',
                    help='path to output folder (default: none, current dir)')
parser.add_argument('--eval-metric', default='prec1', type=str, metavar='EVAL_METRIC',
                    help='Best metric (default: "prec1")')
parser.add_argument("--local_rank", default=0, type=int)


def main():
    args = parser.parse_args()

    args.prefetcher = not args.no_prefetcher
    # distributed (multi-process) training is signalled by the launcher via the WORLD_SIZE env var
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            print('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    r = -1
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        r = torch.distributed.get_rank()

    if args.distributed:
        print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
              % (r, args.world_size))
    else:
        print('Training with a single process on %d GPUs.' % args.num_gpu)

    # FIXME seed handling for multi-process distributed?
    torch.manual_seed(args.seed)

    output_dir = ''
    if args.local_rank == 0:
        if args.output:
            output_base = args.output
        else:
            output_base = './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model,
            str(args.img_size)])
        output_dir = get_outdir(output_base, 'train', exp_name)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        checkpoint_path=args.initial_checkpoint)

    print('Model %s created, param count: %d' %
          (args.model, sum([m.numel() for m in model.parameters()])))

    data_config = resolve_data_config(model, args, verbose=args.local_rank == 0)

    # optionally resume from a checkpoint
    start_epoch = 0
    optimizer_state = None
    if args.resume:
        optimizer_state, start_epoch = resume_checkpoint(model, args.resume, args.start_epoch)

    if args.num_gpu > 1:
        if args.amp:
            print('Warning: AMP does not work well with nn.DataParallel, disabling. '
                  'Use distributed mode for multi-GPU AMP.')
            args.amp = False
        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    optimizer = create_optimizer(args, model)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)

    # Apex AMP must be initialized before the model is wrapped in DistributedDataParallel
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
        print('AMP enabled')
    else:
        use_amp = False
        print('AMP disabled')

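    # wrap with Apex DDP; delay_allreduce defers gradient all-reduce until backward completes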
    if args.distributed:
        model = DDP(model, delay_allreduce=True)

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        print('Scheduled epochs: ', num_epochs)

    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        print('Error: training folder does not exist at: %s' % train_dir)
        exit(1)
    dataset_train = Dataset(train_dir)

    collate_fn = None
    if args.prefetcher and args.mixup > 0:
        # with the prefetcher, mixup (and the smoothed mixup targets) is applied in the collate fn
        collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)

    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        rand_erase_prob=args.reprob,
        rand_erase_mode=args.remode,
        interpolation='random',  # FIXME cleanly resolve this? data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
    )

    eval_dir = os.path.join(args.data, 'validation')
    if not os.path.isdir(eval_dir):
        print('Error: validation folder does not exist at: %s' % eval_dir)
        exit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
    )

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        train_loss_fn = SparseLabelCrossEntropy().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    saver = None
    if output_dir:
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
    best_metric = None
    best_epoch = None
    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp)

            eval_metrics = validate(
                model, loader_eval, validate_loss_fn, args)

            if lr_scheduler is not None:
                lr_scheduler.step(epoch, eval_metrics[eval_metric])

            update_summary(
                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                best_metric, best_epoch = saver.save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': args.model,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'args': args,
                    },
                    epoch=epoch + 1,
                    metric=eval_metrics[eval_metric])

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        print('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))


def train_epoch(
        epoch, model, loader, optimizer, loss_fn, args,
        lr_scheduler=None, saver=None, output_dir='', use_amp=False):

    if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
        if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
            # disable collate-level mixup once the configured cutoff epoch is reached
            loader.mixup_enabled = False

    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()

    model.train()

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        # without the prefetcher, transfer to GPU and apply mixup here on the full batch
        if not args.prefetcher:
            input = input.cuda()
            target = target.cuda()
            if args.mixup > 0.:
                lam = 1.
                if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
                    lam = np.random.beta(args.mixup, args.mixup)
                input.mul_(lam).add_(1 - lam, input.flip(0))
                target = mixup_target(target, args.num_classes, lam, args.smoothing)

        output = model(input)

        loss = loss_fn(output, target)
        if not args.distributed:
            losses_m.update(loss.item(), input.size(0))

        optimizer.zero_grad()
        if use_amp:
            # let Apex AMP scale the loss so fp16 gradients don't underflow
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        torch.cuda.synchronize()
        num_updates += 1

        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), input.size(0))

            if args.local_rank == 0:
                print('Train: {} [{}/{} ({:.0f}%)]  '
                      'Loss: {loss.val:.6f} ({loss.avg:.4f})  '
                      'Time: {batch_time.val:.3f}s, {rate:.3f}/s  '
                      '({batch_time.avg:.3f}s, {rate_avg:.3f}/s)  '
                      'LR: {lr:.4f}  '
                      'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                    epoch,
                    batch_idx, len(loader),
                    100. * batch_idx / last_idx,
                    loss=losses_m,
                    batch_time=batch_time_m,
                    rate=input.size(0) * args.world_size / batch_time_m.val,
                    rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
                    lr=lr,
                    data_time=data_time_m))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        input,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        # periodically (and at epoch end) write a recovery checkpoint from the rank 0 process
        if saver is not None and args.local_rank == 0 and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            save_epoch = epoch + 1 if last_batch else epoch
            saver.save_recovery({
                'epoch': save_epoch,
                'arch': args.model,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'args': args,
                },
                epoch=save_epoch,
                batch_idx=batch_idx)

        if lr_scheduler is not None:
            # per-iteration scheduler update for schedulers that adjust LR within an epoch
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()

    return OrderedDict([('loss', losses_m.avg)])


def validate(model, loader, loss_fn, args):
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    prec1_m = AverageMeter()
    prec5_m = AverageMeter()

    model.eval()

    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (input, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            if not args.prefetcher:
                # the prefetcher normally moves batches to GPU; do it here when it's disabled
                input = input.cuda()
                target = target.cuda()

            output = model(input)
            if isinstance(output, (tuple, list)):
                output = output[0]

            # augmentation reduction: average predictions over the TTA oversampling factor
            reduce_factor = args.tta
            if reduce_factor > 1:
                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]

            loss = loss_fn(output, target)
            prec1, prec5 = accuracy(output, target, topk=(1, 5))

            # in distributed mode, average the metrics across processes before logging
            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                prec1 = reduce_tensor(prec1, args.world_size)
                prec5 = reduce_tensor(prec5, args.world_size)
            else:
                reduced_loss = loss.data

            torch.cuda.synchronize()

            losses_m.update(reduced_loss.item(), input.size(0))
            prec1_m.update(prec1.item(), output.size(0))
            prec5_m.update(prec5.item(), output.size(0))

            batch_time_m.update(time.time() - end)
            end = time.time()
            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                      'Loss {loss.val:.4f} ({loss.avg:.4f})  '
                      'Prec@1 {top1.val:.4f} ({top1.avg:.4f})  '
                      'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format(
                    batch_idx, last_idx,
                    batch_time=batch_time_m, loss=losses_m,
                    top1=prec1_m, top5=prec5_m))

    metrics = OrderedDict([('loss', losses_m.avg), ('prec1', prec1_m.avg), ('prec5', prec5_m.avg)])

    return metrics


def reduce_tensor(tensor, n):
    # sum-reduce across all distributed processes, then divide by world size to get the mean
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= n
    return rt


if __name__ == '__main__':
    main()