# train.py
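# ImageNet-style training script: supports single-GPU, multi-GPU via nn.DataParallel,
# and distributed one-process-per-GPU training, with optional NVIDIA Apex mixed precision.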
import argparse
import os
import time
from collections import OrderedDict
from datetime import datetime

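# NVIDIA Apex is optional: it provides the amp mixed-precision utilities and its own
# DistributedDataParallel wrapper; without it, training runs in full fp32.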
try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as DDP
    has_apex = True
except ImportError:
    has_apex = False

from data import Dataset, create_loader, resolve_data_config
from models import create_model, resume_checkpoint
from utils import *
from loss import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
from optim import create_optimizer
from scheduler import create_scheduler

import torch
import torch.nn as nn
import torch.distributed as dist
import numpy as np
import torchvision.utils

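# cudnn benchmark mode auto-tunes convolution algorithms; fastest when input sizes are fixed.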
torch.backends.cudnn.benchmark = True

parser = argparse.ArgumentParser(description='Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
                    help='Name of model to train (default: "resnet101")')
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
                    help='number of label classes (default: 1000)')
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "sgd")')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                    help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
                    help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")')
parser.add_argument('--tta', type=int, default=0, metavar='N',
                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument('--pretrained', action='store_true', default=False,
                    help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--img-size', type=int, default=224, metavar='N',
                    help='Image patch size (default: 224)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
                    help='input batch size for training (default: 32)')
parser.add_argument('-s', '--initial-batch-size', type=int, default=0, metavar='N',
                    help='initial input batch size for training (default: 0)')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train (default: 200)')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=int, default=30, metavar='N',
                    help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
                    help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                    help='LR decay rate (default: 0.1)')
parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
                    help='LR scheduler (default: "step")')
parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
                    help='Dropout rate (default: 0.0)')
parser.add_argument('--reprob', type=float, default=0.4, metavar='PCT',
                    help='Random erase prob (default: 0.4)')
parser.add_argument('--remode', type=str, default='const',
                    help='Random erase mode (default: "const")')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
                    help='warmup learning rate (default: 0.0001)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.0001,
                    help='weight decay (default: 0.0001)')
parser.add_argument('--mixup', type=float, default=0.0,
                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                    help='turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
                    help='label smoothing (default: 0.1)')
parser.add_argument('--bn-tf', action='store_true', default=False,
                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
parser.add_argument('--bn-momentum', type=float, default=None,
                    help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None,
                    help='BatchNorm epsilon override (if not None)')
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=1000, metavar='N',
                    help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--num-gpu', type=int, default=1,
                    help='Number of GPUs to use')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                    help='path to init checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--save-images', action='store_true', default=False,
                    help='save images of input batches every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
                    help='use NVIDIA amp for mixed precision training')
parser.add_argument('--output', default='', type=str, metavar='PATH',
                    help='path to output folder (default: none, current dir)')
parser.add_argument('--eval-metric', default='prec1', type=str, metavar='EVAL_METRIC',
                    help='Best metric (default: "prec1")')
parser.add_argument("--local_rank", default=0, type=int)


def main():
    args = parser.parse_args()

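    # WORLD_SIZE is set by the distributed launcher (torch.distributed.launch); more than
    # one process means distributed training with one GPU per process.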
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            print('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    r = -1
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
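        # env:// init reads MASTER_ADDR/MASTER_PORT (and rank/world size) from the environment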
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        r = torch.distributed.get_rank()

    if args.distributed:
        print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
              % (r, args.world_size))
    else:
        print('Training with a single process on %d GPUs.' % args.num_gpu)

    # FIXME seed handling for multi-process distributed?
    torch.manual_seed(args.seed)

    output_dir = ''
    if args.local_rank == 0:
        if args.output:
            output_base = args.output
        else:
            output_base = './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model,
            str(args.img_size)])
        output_dir = get_outdir(output_base, 'train', exp_name)

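    # construct the model via the create_model factory; --initial-checkpoint (if given) is loaded inside it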
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        checkpoint_path=args.initial_checkpoint)

    print('Model %s created, param count: %d' %
          (args.model, sum([m.numel() for m in model.parameters()])))

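    # resolve_data_config merges the model's default input config (size, interpolation,
    # mean/std) with any command-line overrides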
    data_config = resolve_data_config(model, args, verbose=args.local_rank == 0)

    # optionally resume from a checkpoint
    start_epoch = 0
    optimizer_state = None
    if args.resume:
        optimizer_state, start_epoch = resume_checkpoint(model, args.resume, args.start_epoch)

    if args.num_gpu > 1:
        if args.amp:
            print('Warning: AMP does not work well with nn.DataParallel, disabling. '
                  'Use distributed mode for multi-GPU AMP.')
            args.amp = False
        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    optimizer = create_optimizer(args, model)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)

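    # Apex AMP 'O1' patches torch functions to run in mixed precision while keeping
    # fp32 master weights; without apex (or --amp), training stays in full fp32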
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
        print('AMP enabled')
    else:
        use_amp = False
        print('AMP disabled')

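    # apex DDP with delay_allreduce=True defers gradient all-reduce until the end of the
    # backward pass; note that distributed mode here relies on apex being installed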
    if args.distributed:
        model = DDP(model, delay_allreduce=True)

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        print('Scheduled epochs: ', num_epochs)

    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        print('Error: training folder does not exist at: %s' % train_dir)
        exit(1)
    dataset_train = Dataset(train_dir)

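    # use_prefetcher enables the CUDA prefetching loader, which overlaps host-to-device
    # copies (and normalization) with compute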
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=True,
        rand_erase_prob=args.reprob,
        rand_erase_mode=args.remode,
        interpolation='random',  # FIXME cleanly resolve this? data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
    )

    eval_dir = os.path.join(args.data, 'validation')
    if not os.path.isdir(eval_dir):
        print('Error: validation folder does not exist at: %s' % eval_dir)
        exit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=True,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
    )

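    # with mixup enabled, targets become mixed label vectors (smoothing folded in via
    # mixup_target), so the soft-target loss is used for training; validation always
    # scores hard labels with standard cross entropy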
    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        train_loss_fn = SparseLabelCrossEntropy().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    saver = None
    if output_dir:
        decreasing = eval_metric == 'loss'
        saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
    best_metric = None
    best_epoch = None
    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp)

            eval_metrics = validate(
                model, loader_eval, validate_loss_fn, args)

            if lr_scheduler is not None:
                lr_scheduler.step(epoch, eval_metrics[eval_metric])

            update_summary(
                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                write_header=best_metric is None)

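            # CheckpointSaver ranks checkpoints by the eval metric and returns the best seen so far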
            if saver is not None:
                # save proper checkpoint with eval metric
                best_metric, best_epoch = saver.save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': args.model,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'args': args,
                    },
                    epoch=epoch + 1,
                    metric=eval_metrics[eval_metric])

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        print('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))


def train_epoch(
        epoch, model, loader, optimizer, loss_fn, args,
        lr_scheduler=None, saver=None, output_dir='', use_amp=False):

    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()

    model.train()

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)

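        # mixup: blend each sample with the batch reversed along dim 0 using a
        # Beta(alpha, alpha)-sampled weight, and mix the (smoothed) one-hot targets to match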
        if args.mixup > 0.:
            lam = 1.
            if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
                lam = np.random.beta(args.mixup, args.mixup)
            input.mul_(lam).add_(input.flip(0), alpha=1 - lam)
            target = mixup_target(target, args.num_classes, lam, args.smoothing)

        output = model(input)

        loss = loss_fn(output, target)
        if not args.distributed:
            losses_m.update(loss.item(), input.size(0))

        optimizer.zero_grad()
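        # amp.scale_loss scales the loss to avoid fp16 gradient underflow and unscales
        # gradients before the optimizer step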
        if use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        torch.cuda.synchronize()
        num_updates += 1

        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), input.size(0))

            if args.local_rank == 0:
                print('Train: {} [{}/{} ({:.0f}%)]  '
                      'Loss: {loss.val:.6f} ({loss.avg:.4f})  '
                      'Time: {batch_time.val:.3f}s, {rate:.3f}/s  '
                      '({batch_time.avg:.3f}s, {rate_avg:.3f}/s)  '
                      'LR: {lr:.4f}  '
                      'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                    epoch,
                    batch_idx, len(loader),
                    100. * batch_idx / last_idx,
                    loss=losses_m,
                    batch_time=batch_time_m,
                    rate=input.size(0) * args.world_size / batch_time_m.val,
                    rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
                    lr=lr,
                    data_time=data_time_m))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        input,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

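        # write a periodic 'recovery' checkpoint (rank 0 only) every --recovery-interval
        # batches and at the end of the epoch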
        if args.local_rank == 0 and (
                saver is not None and last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            save_epoch = epoch + 1 if last_batch else epoch
            saver.save_recovery({
                'epoch': save_epoch,
                'arch': args.model,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'args': args,
                },
                epoch=save_epoch,
                batch_idx=batch_idx)

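        # per-iteration scheduler hook: schedulers that adjust LR every optimizer update use num_updates here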
        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()

    return OrderedDict([('loss', losses_m.avg)])


def validate(model, loader, loss_fn, args):
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    prec1_m = AverageMeter()
    prec5_m = AverageMeter()

    model.eval()

    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (input, target) in enumerate(loader):
            last_batch = batch_idx == last_idx

            output = model(input)
            if isinstance(output, (tuple, list)):
                output = output[0]

            # augmentation reduction
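            # with --tta N, each group of N consecutive outputs corresponds to augmented
            # copies of one sample: average them and keep every N-th target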
            reduce_factor = args.tta
            if reduce_factor > 1:
                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]

            loss = loss_fn(output, target)
            prec1, prec5 = accuracy(output, target, topk=(1, 5))

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                prec1 = reduce_tensor(prec1, args.world_size)
                prec5 = reduce_tensor(prec5, args.world_size)
            else:
                reduced_loss = loss.data

            torch.cuda.synchronize()

            losses_m.update(reduced_loss.item(), input.size(0))
            prec1_m.update(prec1.item(), output.size(0))
            prec5_m.update(prec5.item(), output.size(0))

            batch_time_m.update(time.time() - end)
            end = time.time()
            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                      'Loss {loss.val:.4f} ({loss.avg:.4f})  '
                      'Prec@1 {top1.val:.4f} ({top1.avg:.4f})  '
                      'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format(
                    batch_idx, last_idx,
                    batch_time=batch_time_m, loss=losses_m,
                    top1=prec1_m, top5=prec5_m))

    metrics = OrderedDict([('loss', losses_m.avg), ('prec1', prec1_m.avg), ('prec5', prec5_m.avg)])

    return metrics


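# average a tensor across all distributed processes (all-reduce sum, then divide by world size)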
def reduce_tensor(tensor, n):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= n
    return rt


if __name__ == '__main__':
    main()