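""" ImageNet training script.

Builds a model, optimizer and LR scheduler from the local data/models/optim/scheduler
helpers and trains with optional apex AMP mixed precision and distributed training
(one GPU per process).

Illustrative launch only (paths and model name are placeholders); distributed runs are
typically started with the PyTorch launcher, which sets WORLD_SIZE and passes
--local_rank to this script:

    python -m torch.distributed.launch --nproc_per_node=2 train.py /path/to/imagenet --model resnet50 --amp
"""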
import argparse
import os
import time
from collections import OrderedDict
from datetime import datetime

try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as DDP
    has_apex = True
except ImportError:
    has_apex = False

from data import *
from models import create_model, resume_checkpoint
from utils import *
from loss import LabelSmoothingCrossEntropy
from optim import create_optimizer
from scheduler import create_scheduler

import torch
import torch.nn as nn
import torch.distributed as dist
import torchvision.utils

torch.backends.cudnn.benchmark = True

parser = argparse.ArgumentParser(description='Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
                    help='Name of model to train (default: "resnet101")')
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
                    help='number of label classes (default: 1000)')
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "sgd")')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                    help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
                    help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")')
parser.add_argument('--tta', type=int, default=0, metavar='N',
                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument('--pretrained', action='store_true', default=False,
                    help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--img-size', type=int, default=224, metavar='N',
                    help='Image patch size (default: 224)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
                    help='input batch size for training (default: 32)')
parser.add_argument('-s', '--initial-batch-size', type=int, default=0, metavar='N',
                    help='initial input batch size for training (default: 0)')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train (default: 200)')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=int, default=30, metavar='N',
                    help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
                    help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                    help='LR decay rate (default: 0.1)')
parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
                    help='LR scheduler (default: "step")')
parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
                    help='Dropout rate (default: 0.0)')
parser.add_argument('--reprob', type=float, default=0.4, metavar='PCT',
                    help='Random erase prob (default: 0.4)')
parser.add_argument('--repp', action='store_true', default=False,
                    help='Random erase per-pixel (default: False)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
                    help='warmup learning rate (default: 0.0001)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.0001, metavar='M',
                    help='weight decay (default: 0.0001)')
parser.add_argument('--smoothing', type=float, default=0.1, metavar='M',
                    help='label smoothing (default: 0.1)')
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=1000, metavar='N',
                    help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                    help='how many data loading processes to use (default: 4)')
parser.add_argument('--num-gpu', type=int, default=1,
                    help='Number of GPUs to use')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                    help='path to init checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--save-images', action='store_true', default=False,
                    help='save images of input batches every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
                    help='use NVIDIA amp for mixed precision training')
parser.add_argument('--output', default='', type=str, metavar='PATH',
                    help='path to output folder (default: none, current dir)')
parser.add_argument('--eval-metric', default='prec1', type=str, metavar='EVAL_METRIC',
                    help='Best metric (default: "prec1")')
parser.add_argument("--local_rank", default=0, type=int)


def main():
    args = parser.parse_args()

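    # distributed setup: the launcher sets WORLD_SIZE; each process pins itself to a
    # single GPU and joins the NCCL process group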
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            print('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    r = -1
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        r = torch.distributed.get_rank()

    if args.distributed:
        print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
              % (r, args.world_size))
    else:
        print('Training with a single process on %d GPUs.' % args.num_gpu)

    # FIXME seed handling for multi-process distributed?
    torch.manual_seed(args.seed)

    output_dir = ''
    if args.local_rank == 0:
        if args.output:
            output_base = args.output
        else:
            output_base = './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model,
            str(args.img_size)])
        output_dir = get_outdir(output_base, 'train', exp_name)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        global_pool=args.gp,
        checkpoint_path=args.initial_checkpoint)

    print('Model %s created, param count: %d' %
          (args.model, sum([m.numel() for m in model.parameters()])))

    data_config = resolve_data_config(model, args, verbose=args.local_rank == 0)

    # optionally resume from a checkpoint
    start_epoch = 0
    optimizer_state = None
    if args.resume:
        optimizer_state, start_epoch = resume_checkpoint(model, args.resume, args.start_epoch)

    if args.num_gpu > 1:
        if args.amp:
            print('Warning: AMP does not work well with nn.DataParallel, disabling. '
                  'Use distributed mode for multi-GPU AMP.')
            args.amp = False
        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    optimizer = create_optimizer(args, model.parameters())
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)

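    # initialize apex AMP at opt level O1 (mixed precision) when it is installed and requested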
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
        print('AMP enabled')
    else:
        use_amp = False
        print('AMP disabled')

    if args.distributed:
        model = DDP(model, delay_allreduce=True)

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        print('Scheduled epochs: ', num_epochs)

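    # build datasets and loaders; create_loader applies the resolved data config
    # (input size, mean/std, interpolation) and, when distributed, a distributed sampler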
    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        print('Error: training folder does not exist at: %s' % train_dir)
        exit(1)
    dataset_train = Dataset(train_dir)

    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=True,
        rand_erase_prob=args.reprob,
        rand_erase_pp=args.repp,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
    )

    eval_dir = os.path.join(args.data, 'validation')
    if not os.path.isdir(eval_dir):
        print('Error: validation folder does not exist at: %s' % eval_dir)
        exit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=True,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
    )

    if args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    saver = None
    if output_dir:
        decreasing = eval_metric == 'loss'
        saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
    best_metric = None
    best_epoch = None
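    # main loop: train one epoch, validate, step the LR scheduler on the eval metric,
    # append a summary row and checkpoint the best model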
    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp)

            eval_metrics = validate(
                model, loader_eval, validate_loss_fn, args)

            if lr_scheduler is not None:
                lr_scheduler.step(epoch, eval_metrics[eval_metric])

            update_summary(
                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                best_metric, best_epoch = saver.save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': args.model,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'args': args,
                    },
                    epoch=epoch + 1,
                    metric=eval_metrics[eval_metric])

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        print('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))


def train_epoch(
        epoch, model, loader, optimizer, loss_fn, args,
        lr_scheduler=None, saver=None, output_dir='', use_amp=False):

    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()

    model.train()

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)

        output = model(input)

        loss = loss_fn(output, target)
        if not args.distributed:
            losses_m.update(loss.item(), input.size(0))

        optimizer.zero_grad()
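        # with AMP, backprop through a dynamically scaled loss to avoid fp16 underflow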
        if use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        torch.cuda.synchronize()
        num_updates += 1

        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), input.size(0))

            if args.local_rank == 0:
                print('Train: {} [{}/{} ({:.0f}%)]  '
                      'Loss: {loss.val:.6f} ({loss.avg:.4f})  '
                      'Time: {batch_time.val:.3f}s, {rate:.3f}/s  '
                      '({batch_time.avg:.3f}s, {rate_avg:.3f}/s)  '
                      'LR: {lr:.4f}  '
                      'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                    epoch,
                    batch_idx, len(loader),
                    100. * batch_idx / last_idx,
                    loss=losses_m,
                    batch_time=batch_time_m,
                    rate=input.size(0) * args.world_size / batch_time_m.val,
                    rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
                    lr=lr,
                    data_time=data_time_m))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        input,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.local_rank == 0 and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            save_epoch = epoch + 1 if last_batch else epoch
            saver.save_recovery({
                'epoch': save_epoch,
                'arch': args.model,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'args': args,
                },
                epoch=save_epoch,
                batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()

    return OrderedDict([('loss', losses_m.avg)])


def validate(model, loader, loss_fn, args):
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    prec1_m = AverageMeter()
    prec5_m = AverageMeter()

    model.eval()

    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (input, target) in enumerate(loader):
            last_batch = batch_idx == last_idx

            output = model(input)
            if isinstance(output, (tuple, list)):
                output = output[0]

            # augmentation reduction
            reduce_factor = args.tta
            if reduce_factor > 1:
                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]

            loss = loss_fn(output, target)
            prec1, prec5 = accuracy(output, target, topk=(1, 5))

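            # in distributed mode, average loss/accuracy across processes so every rank
            # logs the same values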
            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                prec1 = reduce_tensor(prec1, args.world_size)
                prec5 = reduce_tensor(prec5, args.world_size)
            else:
                reduced_loss = loss.data

            torch.cuda.synchronize()

            losses_m.update(reduced_loss.item(), input.size(0))
            prec1_m.update(prec1.item(), output.size(0))
            prec5_m.update(prec5.item(), output.size(0))

            batch_time_m.update(time.time() - end)
            end = time.time()
            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                      'Loss {loss.val:.4f} ({loss.avg:.4f})  '
                      'Prec@1 {top1.val:.4f} ({top1.avg:.4f})  '
                      'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format(
                    batch_idx, last_idx,
                    batch_time=batch_time_m, loss=losses_m,
                    top1=prec1_m, top5=prec5_m))

    metrics = OrderedDict([('loss', losses_m.avg), ('prec1', prec1_m.avg), ('prec5', prec5_m.avg)])

    return metrics


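# sum a tensor across all processes and divide by the process count (global mean)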
def reduce_tensor(tensor, n):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= n
    return rt


if __name__ == '__main__':
    main()