import argparse
import os
import time
from collections import OrderedDict
from datetime import datetime

try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as DDP
    has_apex = True
except ImportError:
    has_apex = False

from data import *
from models import model_factory
from utils import *
from optim import Nadam, AdaBound
from loss import LabelSmoothingCrossEntropy
import scheduler

# NOTE: the torch imports come after the local star imports so that names such
# as `data` and `optim` end up bound to the torch modules.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torch.distributed as dist
import torchvision.utils

# Let cuDNN benchmark and cache the fastest convolution algorithms; this pays
# off when input sizes stay fixed (images are resized to --img-size).
torch.backends.cudnn.benchmark = True

parser = argparse.ArgumentParser(description='Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
                    help='Name of model to train (default: "resnet101")')
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                    help='Optimizer: "sgd", "adam", "nadam", "adabound", "adadelta", '
                         '"rmsprop" (default: "sgd")')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                    help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
                    help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")')
parser.add_argument('--tta', type=int, default=0, metavar='N',
                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument('--pretrained', action='store_true', default=False,
                    help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--img-size', type=int, default=224, metavar='N',
                    help='Image patch size (default: 224)')
parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
                    help='input batch size for training (default: 32)')
parser.add_argument('-s', '--initial-batch-size', type=int, default=0, metavar='N',
                    help='initial input batch size for training (default: 0)')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train (default: 200)')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=int, default=30, metavar='N',
                    help='epoch interval to decay LR for the step scheduler (default: 30)')
parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
                    help='epochs to warmup LR, if scheduler supports (default: 3)')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                    help='LR decay rate (default: 0.1)')
parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
                    help='LR scheduler: "step", "cosine", "tanh" (default: "step")')
parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
                    help='Dropout rate (default: 0.0)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
                    help='warmup learning rate (default: 0.0001)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.0001, metavar='M',
                    help='weight decay (default: 0.0001)')
parser.add_argument('--smoothing', type=float, default=0.1, metavar='M',
                    help='label smoothing, 0 disables (default: 0.1)')
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=1000, metavar='N',
                    help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                    help='how many data loading worker processes to use (default: 4)')
parser.add_argument('--num-gpu', type=int, default=1,
                    help='Number of GPUS to use (default: 1)')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                    help='path to init checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--save-images', action='store_true', default=False,
                    help='save images of input batches every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
                    help='use NVIDIA amp for mixed precision training (requires apex)')
parser.add_argument('--output', default='', type=str, metavar='PATH',
                    help='path to output folder (default: none, current dir)')
parser.add_argument("--local_rank", default=0, type=int,
                    help='local process rank; selects the CUDA device in distributed mode (default: 0)')


def main():
    """Train a model end to end from command-line arguments.

    Parses `parser` arguments, optionally initializes distributed (NCCL) mode,
    builds train/validation loaders, the model, loss functions, optimizer and LR
    scheduler, optionally resumes from a checkpoint, then runs the epoch loop.
    A KeyboardInterrupt stops training early without raising.

    Returns:
        False if --resume points to a file that does not exist; otherwise None.
    """
    args = parser.parse_args()

    # Distributed mode is enabled when the launcher exports WORLD_SIZE > 1; in
    # that mode each process drives exactly one GPU.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            print('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    rank = -1
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        print('Training in distributed mode with %d processes, 1 GPU per process. Process %d.'
              % (args.world_size, rank))
    else:
        print('Training with a single process with %d GPUs.' % args.num_gpu)

    # Only the local-rank-0 process gets an output directory; every other
    # process keeps output_dir empty and so never writes checkpoints/summaries.
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model,
            str(args.img_size)])
        output_dir = get_outdir(output_base, 'train', exp_name)

    batch_size = args.batch_size
    torch.manual_seed(args.seed)

    data_mean, data_std = get_model_meanstd(args.model)

    dataset_train = Dataset(os.path.join(args.data, 'train'))
    loader_train = create_loader(
        dataset_train,
        img_size=args.img_size,
        batch_size=batch_size,
        is_training=True,
        use_prefetcher=True,
        random_erasing=0.3,
        mean=data_mean,
        std=data_std,
        num_workers=args.workers,
        distributed=args.distributed,
    )

    dataset_eval = Dataset(os.path.join(args.data, 'validation'))
    loader_eval = create_loader(
        dataset_eval,
        img_size=args.img_size,
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=True,
        mean=data_mean,
        std=data_std,
        num_workers=args.workers,
        distributed=args.distributed,
    )

    model = model_factory.create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=1000,
        drop_rate=args.drop,
        global_pool=args.gp,
        checkpoint_path=args.initial_checkpoint)

    # optionally resume from a checkpoint
    start_epoch = 0 if args.start_epoch is None else args.start_epoch
    optimizer_state = None
    if args.resume:
        if not os.path.isfile(args.resume):
            print("=> no checkpoint found at '{}'".format(args.resume))
            return False
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            # Checkpoints saved from an nn.DataParallel / DDP-wrapped model have
            # every key prefixed with 'module.'; strip exactly that prefix.
            prefix = 'module.'
            new_state_dict = OrderedDict()
            for key, value in checkpoint['state_dict'].items():
                name = key[len(prefix):] if key.startswith(prefix) else key
                new_state_dict[name] = value
            model.load_state_dict(new_state_dict)
            if 'optimizer' in checkpoint:
                optimizer_state = checkpoint['optimizer']
            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
            start_epoch = checkpoint['epoch'] if args.start_epoch is None else args.start_epoch
        else:
            # Plain state dict without metadata.
            model.load_state_dict(checkpoint)

    # Label smoothing is applied to the training loss only; validation always
    # uses plain cross entropy.
    if args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    if args.num_gpu > 1:
        if args.amp:
            print('Warning: AMP does not work well with nn.DataParallel, disabling. '
                  'Use distributed mode for multi-GPU AMP.')
            args.amp = False
        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    # The optimizer is created after the model is moved to the GPU so it holds
    # the device parameters; resumed state is loaded afterwards.
    optimizer = create_optimizer(args, model.parameters())
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)

    if args.amp and not has_apex:
        print('Warning: --amp requested but apex is not installed.')
    use_amp = bool(has_apex and args.amp)
    if use_amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O3')
        print('AMP enabled')
    else:
        print('AMP disabled')

    if args.distributed:
        model = DDP(model, delay_allreduce=True)

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    if args.local_rank == 0:
        print('Scheduled epochs: ', num_epochs)

    saver = None
    if output_dir:
        saver = CheckpointSaver(checkpoint_dir=output_dir)

    # Holds whatever saver.save_checkpoint returns; indexed below as
    # (epoch, loss) — assumed from that usage, confirm against CheckpointSaver.
    best_loss = None
    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                # Let the (presumably DistributedSampler) sampler vary its
                # shuffling per epoch.
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp)

            eval_metrics = validate(
                model, loader_eval, validate_loss_fn, args)

            if lr_scheduler is not None:
                lr_scheduler.step(epoch, eval_metrics['eval_loss'])

            # Only the process owning an output directory writes the summary;
            # other ranks would otherwise write 'summary.csv' into the CWD.
            if output_dir:
                update_summary(
                    epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                    write_header=best_loss is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                best_loss = saver.save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': args.model,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'args': args,
                    },
                    epoch=epoch + 1,
                    metric=eval_metrics['eval_loss'])

    except KeyboardInterrupt:
        pass

    # best_loss stays None when nothing was saved (processes without a saver,
    # or an interrupt before the first epoch completed).
    if best_loss is not None:
        print('*** Best loss: {0} (epoch {1})'.format(best_loss[1], best_loss[0]))


def train_epoch(
        epoch, model, loader, optimizer, loss_fn, args,
        lr_scheduler=None, saver=None, output_dir='', use_amp=False):
    """Run one training epoch over `loader`.

    Args:
        epoch: index of the current epoch; used for logging, recovery
            checkpoints and the scheduler's global update counter.
        model: network to train (switched to train mode here).
        loader: sized iterable of (input, target) batches.
        optimizer: optimizer stepped once per batch.
        loss_fn: criterion called as ``loss_fn(output, target)``.
        args: parsed namespace; reads distributed, world_size, local_rank,
            log_interval, recovery_interval, save_images and model.
        lr_scheduler: optional scheduler; ``step_update`` is called every batch.
        saver: optional checkpoint saver used for periodic recovery checkpoints.
        output_dir: directory for debug batch images (--save-images).
        use_amp: route the backward pass through apex ``amp.scale_loss``.

    Returns:
        OrderedDict with the running average training loss as 'train_loss'.
    """
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()

    model.train()

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (inputs, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)

        output = model(inputs)
        loss = loss_fn(output, target)
        if not args.distributed:
            # In distributed mode the loss is all-reduced and recorded only at
            # log steps (below) to avoid a collective on every batch.
            losses_m.update(loss.item(), inputs.size(0))

        optimizer.zero_grad()
        if use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        # Wait for queued GPU work so batch_time reflects real compute time.
        torch.cuda.synchronize()
        num_updates += 1

        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                # Every rank must join the all-reduce, so this stays outside
                # the rank-0 check.
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), inputs.size(0))

            if args.local_rank == 0:
                # Guard against a single-batch loader (last_idx == 0).
                percent = 100. * batch_idx / last_idx if last_idx > 0 else 100.
                print('Train: {} [{}/{} ({:.0f}%)]  '
                      'Loss: {loss.val:.6f} ({loss.avg:.4f})  '
                      'Time: {batch_time.val:.3f}s, {rate:.3f}/s  '
                      '({batch_time.avg:.3f}s, {rate_avg:.3f}/s)  '
                      'LR: {lr:.4f}  '
                      'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                    epoch,
                    batch_idx, len(loader),
                    percent,
                    loss=losses_m,
                    batch_time=batch_time_m,
                    rate=inputs.size(0) * args.world_size / batch_time_m.val,
                    rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,
                    lr=lr,
                    data_time=data_time_m))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        inputs,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        # Recovery checkpoint at the end of the epoch and every
        # recovery_interval batches; requires a saver (rank 0 only).
        if args.local_rank == 0 and saver is not None and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            save_epoch = epoch + 1 if last_batch else epoch
            saver.save_recovery({
                'epoch': save_epoch,
                'arch': args.model,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'args': args,
                },
                epoch=save_epoch,
                batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()

    return OrderedDict([('train_loss', losses_m.avg)])


def validate(model, loader, loss_fn, args):
    """Evaluate `model` on `loader` without gradient tracking.

    Args:
        model: network to evaluate (switched to eval mode here).
        loader: sized iterable of (input, target) batches.
        loss_fn: criterion called as ``loss_fn(output, target)``.
        args: parsed namespace; reads tta, distributed, world_size,
            local_rank and log_interval.

    Returns:
        OrderedDict with 'eval_loss' and 'eval_prec1' running averages.
    """
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    prec1_m = AverageMeter()
    prec5_m = AverageMeter()

    model.eval()

    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (inputs, target) in enumerate(loader):
            last_batch = batch_idx == last_idx

            output = model(inputs)
            if isinstance(output, (tuple, list)):
                # Keep only the first output when the model returns several.
                output = output[0]

            # Test-time augmentation reduction: assumes the loader emits `tta`
            # consecutive augmented copies of each sample; their rows are
            # averaged and one target per group is kept.
            reduce_factor = args.tta
            if reduce_factor > 1:
                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]

            loss = loss_fn(output, target)
            prec1, prec5 = accuracy(output, target, topk=(1, 5))

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                prec1 = reduce_tensor(prec1, args.world_size)
                prec5 = reduce_tensor(prec5, args.world_size)
            else:
                reduced_loss = loss.data

            torch.cuda.synchronize()

            # Weight every meter by the number of samples actually scored
            # (after TTA reduction) so loss and precision averages agree.
            num_scored = output.size(0)
            losses_m.update(reduced_loss.item(), num_scored)
            prec1_m.update(prec1.item(), num_scored)
            prec5_m.update(prec5.item(), num_scored)

            batch_time_m.update(time.time() - end)
            end = time.time()
            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                      'Loss {loss.val:.4f} ({loss.avg:.4f})  '
                      'Prec@1 {top1.val:.4f} ({top1.avg:.4f})  '
                      'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format(
                    batch_idx, last_idx,
                    batch_time=batch_time_m, loss=losses_m,
                    top1=prec1_m, top5=prec5_m))

    metrics = OrderedDict([('eval_loss', losses_m.avg), ('eval_prec1', prec1_m.avg)])

    return metrics


def create_optimizer(args, parameters):
    """Build the optimizer named by ``args.opt`` (case-insensitive).

    Supported names: 'sgd' (Nesterov momentum), 'adam', 'nadam', 'adabound',
    'adadelta', 'rmsprop'.

    Args:
        args: namespace providing opt, lr, momentum, weight_decay and opt_eps.
        parameters: iterable of parameters (or param groups) to optimize.

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: if ``args.opt`` is not one of the supported names.
    """
    opt_name = args.opt.lower()
    if opt_name == 'sgd':
        optimizer = optim.SGD(
            parameters, lr=args.lr,
            momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    elif opt_name == 'adam':
        optimizer = optim.Adam(
            parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
    elif opt_name == 'nadam':
        optimizer = Nadam(
            parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
    elif opt_name == 'adabound':
        # Starts from lr / 100, with final_lr set to the requested --lr.
        optimizer = AdaBound(
            parameters, lr=args.lr / 100, weight_decay=args.weight_decay, eps=args.opt_eps,
            final_lr=args.lr)
    elif opt_name == 'adadelta':
        optimizer = optim.Adadelta(
            parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
    elif opt_name == 'rmsprop':
        optimizer = optim.RMSprop(
            parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps,
            momentum=args.momentum, weight_decay=args.weight_decay)
    else:
        # An explicit exception (not `assert`) so the check survives `python -O`.
        raise ValueError('Invalid optimizer: %r' % (args.opt,))
    return optimizer


def create_scheduler(args, optimizer):
    """Build the LR scheduler selected by ``args.sched`` and the epoch budget.

    'cosine' and 'tanh' run a single cycle (``cycle_limit=1``) of
    ``args.epochs`` epochs, after which training continues for
    ``cooldown_epochs`` more epochs. Any other ``args.sched`` value falls back
    to a step schedule that decays by ``args.decay_rate`` every
    ``args.decay_epochs`` epochs, and trains for ``args.epochs``.

    Optional attributes ``args.min_lr`` (default 1e-5) and
    ``args.cooldown_epochs`` (default 10) override the cycle schedulers'
    previously hard-coded values when present on the namespace.

    Returns:
        Tuple ``(lr_scheduler, num_epochs)``.
    """
    num_epochs = args.epochs
    min_lr = getattr(args, 'min_lr', 1e-5)
    cooldown_epochs = getattr(args, 'cooldown_epochs', 10)
    if args.sched == 'cosine':
        lr_scheduler = scheduler.CosineLRScheduler(
            optimizer,
            t_initial=num_epochs,
            t_mul=1.0,
            lr_min=min_lr,
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cycle_limit=1,
            t_in_epochs=True,
        )
        num_epochs = lr_scheduler.get_cycle_length() + cooldown_epochs
    elif args.sched == 'tanh':
        lr_scheduler = scheduler.TanhLRScheduler(
            optimizer,
            t_initial=num_epochs,
            t_mul=1.0,
            lr_min=min_lr,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cycle_limit=1,
            t_in_epochs=True,
        )
        num_epochs = lr_scheduler.get_cycle_length() + cooldown_epochs
    else:
        lr_scheduler = scheduler.StepLRScheduler(
            optimizer,
            decay_t=args.decay_epochs,
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
        )
    return lr_scheduler, num_epochs


def reduce_tensor(tensor, n):
    """Return the sum of `tensor` over all processes divided by `n`.

    Callers pass the world size as `n`, giving the cross-process mean. The
    reduction runs on a clone, so `tensor` itself is left unmodified. Requires
    an initialized torch.distributed process group.
    """
    rt = tensor.clone()
    # dist.ReduceOp replaces the deprecated dist.reduce_op alias.
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= n
    return rt


if __name__ == '__main__':
    main()