# train.py
import argparse
import os
import sys
import time
from collections import OrderedDict
from datetime import datetime

try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as DDP
    has_apex = True
except ImportError:
    has_apex = False

from data import *
from models import create_model, resume_checkpoint
from utils import *
from loss import LabelSmoothingCrossEntropy
from optim import create_optimizer
from scheduler import create_scheduler

import torch
import torch.nn as nn
import torch.distributed as dist
import torchvision.utils

# Let cuDNN benchmark and cache the fastest convolution algorithms; this pays off when
# input sizes stay the same between iterations.
torch.backends.cudnn.benchmark = True

# Command-line interface for the training script. Help strings state the actual defaults.
parser = argparse.ArgumentParser(description='Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
                    help='Name of model to train (default: "resnet101")')
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
                    help='number of label classes (default: 1000)')
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "sgd")')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                    help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
                    help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")')
parser.add_argument('--tta', type=int, default=0, metavar='N',
                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument('--pretrained', action='store_true', default=False,
                    help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--img-size', type=int, default=224, metavar='N',
                    help='Image patch size (default: 224)')
parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
                    help='input batch size for training (default: 32)')
parser.add_argument('-s', '--initial-batch-size', type=int, default=0, metavar='N',
                    help='initial input batch size for training (default: 0)')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train (default: 200)')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=int, default=30, metavar='N',
                    help='epoch interval to decay LR (default: 30)')
parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
                    help='epochs to warmup LR, if scheduler supports (default: 3)')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                    help='LR decay rate (default: 0.1)')
parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
                    help='LR scheduler (default: "step")')
parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
                    help='Dropout rate (default: 0.0)')
parser.add_argument('--reprob', type=float, default=0.4, metavar='PCT',
                    help='Random erase prob (default: 0.4)')
parser.add_argument('--repp', action='store_true', default=False,
                    help='Random erase per-pixel (default: False)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
                    help='warmup learning rate (default: 0.0001)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.0001, metavar='M',
                    help='weight decay (default: 0.0001)')
parser.add_argument('--smoothing', type=float, default=0.1, metavar='M',
                    help='label smoothing (default: 0.1)')
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=1000, metavar='N',
                    help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                    help='how many data loading worker processes to use (default: 4)')
parser.add_argument('--num-gpu', type=int, default=1,
                    help='Number of GPUS to use (default: 1)')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                    help='path to init checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--save-images', action='store_true', default=False,
                    help='save images of input batches every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
                    help='use NVIDIA amp for mixed precision training')
parser.add_argument('--output', default='', type=str, metavar='PATH',
                    help='path to output folder (default: none, current dir)')
parser.add_argument('--eval-metric', default='prec1', type=str, metavar='EVAL_METRIC',
                    help='Best metric (default: "prec1")')
parser.add_argument("--local_rank", default=0, type=int,
                    help='local process rank; selects the CUDA device in distributed mode (default: 0)')


def main():
    """Train and validate a model as configured by the command-line arguments.

    Launch modes handled below:
      * single process on one GPU (default);
      * single process on ``--num-gpu`` GPUs via ``nn.DataParallel`` (AMP is turned
        off in this mode);
      * multi-process distributed training when the ``WORLD_SIZE`` environment
        variable is > 1: one GPU per process, NCCL backend, apex
        ``DistributedDataParallel`` (requires apex).

    Only the process with ``local_rank == 0`` creates an output directory. The
    summary CSV and checkpoints are written only when an output directory exists.
    """
    args = parser.parse_args()

    # Distributed mode is enabled when the WORLD_SIZE environment variable is > 1.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            print('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    r = -1
    if args.distributed:
        if not has_apex:
            # DDP is only available from apex. Fail here, before the process group and
            # model are set up, rather than with a NameError much later.
            raise RuntimeError(
                'Distributed training requires NVIDIA apex (apex.parallel.DistributedDataParallel).')
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        r = torch.distributed.get_rank()
        print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
              % (r, args.world_size))
    else:
        print('Training with a single process on %d GPUs.' % args.num_gpu)

    # FIXME seed handling for multi-process distributed?
    torch.manual_seed(args.seed)

    # Only local rank 0 owns an output directory; other ranks keep output_dir == ''.
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model,
            str(args.img_size)])
        output_dir = get_outdir(output_base, 'train', exp_name)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        global_pool=args.gp,
        checkpoint_path=args.initial_checkpoint)

    data_mean, data_std = get_mean_and_std(model, args)

    # optionally resume from a checkpoint
    start_epoch = 0
    optimizer_state = None
    if args.resume:
        start_epoch, optimizer_state = resume_checkpoint(model, args.resume, args.start_epoch)

    if args.num_gpu > 1:
        if args.amp:
            print('Warning: AMP does not work well with nn.DataParallel, disabling. '
                  'Use distributed mode for multi-GPU AMP.')
            args.amp = False
        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    optimizer = create_optimizer(args, model.parameters())
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)

    use_amp = has_apex and args.amp
    if use_amp:
        # O1 is apex's recommended mixed-precision mode. O3 (pure FP16) is documented
        # as a speed baseline, not a stable training setting.
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        print('AMP enabled')
    else:
        print('AMP disabled')

    if args.distributed:
        model = DDP(model, delay_allreduce=True)

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    if lr_scheduler is not None and start_epoch > 0:
        # Put the schedule at the value for the first epoch that will be trained.
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        print('Scheduled epochs: ', num_epochs)

    train_dir = os.path.join(args.data, 'train')
    if not os.path.isdir(train_dir):
        print('Error: training folder does not exist at: %s' % train_dir)
        sys.exit(1)
    dataset_train = Dataset(train_dir)

    loader_train = create_loader(
        dataset_train,
        img_size=args.img_size,
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=True,
        rand_erase_prob=args.reprob,
        rand_erase_pp=args.repp,
        mean=data_mean,
        std=data_std,
        num_workers=args.workers,
        distributed=args.distributed,
    )

    eval_dir = os.path.join(args.data, 'validation')
    if not os.path.isdir(eval_dir):
        print('Error: validation folder does not exist at: %s' % eval_dir)
        sys.exit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        img_size=args.img_size,
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=True,
        mean=data_mean,
        std=data_std,
        num_workers=args.workers,
        distributed=args.distributed,
    )

    # Label smoothing applies to training only; validation always uses plain cross-entropy.
    if args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    saver = None
    if output_dir:
        # Of the metrics validate() reports, only 'loss' improves by decreasing.
        saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=(eval_metric == 'loss'))
    best_metric = None
    best_epoch = None

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                # presumably a DistributedSampler: set_epoch varies its shuffling per epoch
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp)

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args)

            if lr_scheduler is not None:
                # Advance to the value for the *next* epoch. This follows the resume
                # convention above (step(start_epoch) before training start_epoch).
                # Stepping with `epoch` would leave the LR one epoch behind.
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            if output_dir:
                # Only the rank that owns an output directory writes the summary, so other
                # ranks do not drop a 'summary.csv' into the current working directory.
                update_summary(
                    epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                    write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                best_metric, best_epoch = saver.save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': args.model,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'args': args,
                    },
                    epoch=epoch + 1,
                    metric=eval_metrics[eval_metric])

    except KeyboardInterrupt:
        # Deliberate: Ctrl-C ends training early but still reports the best result so far.
        pass

    if best_metric is not None:
        print('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))


def train_epoch(
        epoch, model, loader, optimizer, loss_fn, args,
        lr_scheduler=None, saver=None, output_dir='', use_amp=False):
    """Run one training epoch over ``loader``.

    Args:
        epoch: zero-based epoch index. Used in log lines, to start the global update
            counter at ``epoch * len(loader)``, and for recovery-checkpoint epoch numbers.
        model: network to train; put into train mode here.
        loader: iterable of ``(input, target)`` batches that supports ``len()``.
        optimizer: stepped once per batch.
        loss_fn: criterion called as ``loss_fn(output, target)``.
        args: parsed CLI args. Reads ``distributed``, ``world_size``, ``local_rank``,
            ``log_interval``, ``recovery_interval``, ``save_images`` and ``model``.
        lr_scheduler: optional; ``step_update`` is called after every batch.
        saver: optional checkpoint saver for recovery checkpoints. None disables them.
        output_dir: destination for debug input images when ``--save-images`` is set.
        use_amp: scale the loss through apex ``amp`` before the backward pass.

    Returns:
        ``OrderedDict([('loss', avg)])`` with the running average training loss. In
        distributed mode the loss meter is fed only at logging steps, with the
        cross-process reduced loss.
    """
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()

    model.train()

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (inputs, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)

        output = model(inputs)
        loss = loss_fn(output, target)
        if not args.distributed:
            losses_m.update(loss.item(), inputs.size(0))

        optimizer.zero_grad()
        if use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        # Wait for queued GPU work so batch_time measures completed computation.
        torch.cuda.synchronize()
        num_updates += 1
        batch_time_m.update(time.time() - end)

        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                # Collective op: every rank must reach this line at the same batch.
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), inputs.size(0))

            if args.local_rank == 0:
                # last_idx is 0 for a one-batch loader; report 100% instead of dividing by 0.
                progress = 100. * batch_idx / last_idx if last_idx > 0 else 100.
                print('Train: {} [{}/{} ({:.0f}%)]  '
                      'Loss: {loss.val:.6f} ({loss.avg:.4f})  '
                      'Time: {batch_time.val:.3f}s, {rate:.3f}/s  '
                      '({batch_time.avg:.3f}s, {rate_avg:.3f}/s)  '
                      'LR: {lr:.4f}  '
                      'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                          epoch,
                          batch_idx, len(loader),
                          progress,
                          loss=losses_m,
                          batch_time=batch_time_m,
                          rate=inputs.size(0) * args.world_size / batch_time_m.val,
                          rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,
                          lr=lr,
                          data_time=data_time_m))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        inputs,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        # Recovery checkpoints need a saver. Previously a periodic save with saver=None
        # crashed because of `and`/`or` precedence. A non-positive interval disables
        # periodic saves (the end-of-epoch save still happens).
        periodic_save = (args.recovery_interval > 0
                         and (batch_idx + 1) % args.recovery_interval == 0)
        if args.local_rank == 0 and saver is not None and (last_batch or periodic_save):
            # After the last batch the epoch is complete, so record the next epoch to resume at.
            save_epoch = epoch + 1 if last_batch else epoch
            saver.save_recovery({
                'epoch': save_epoch,
                'arch': args.model,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'args': args,
                },
                epoch=save_epoch,
                batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()

    return OrderedDict([('loss', losses_m.avg)])


def validate(model, loader, loss_fn, args):
    """Evaluate ``model`` on every batch of ``loader`` with gradients disabled.

    When ``args.tta > 1``, each run of ``args.tta`` consecutive rows of the output is
    averaged into one prediction, and every ``args.tta``-th target is kept. This
    assumes the loader emits the augmented copies of a sample contiguously (TODO
    confirm against the loader). In distributed mode the per-batch loss and
    precisions are averaged across processes before they are accumulated.

    Returns:
        OrderedDict with the running averages ``'loss'``, ``'prec1'`` and ``'prec5'``.
    """
    batch_time_m, losses_m, prec1_m, prec5_m = [AverageMeter() for _ in range(4)]

    model.eval()

    tta = args.tta
    last_idx = len(loader) - 1
    end = time.time()
    with torch.no_grad():
        for batch_idx, (inputs, target) in enumerate(loader):
            output = model(inputs)
            if isinstance(output, (tuple, list)):
                # some models return a sequence; only its first element is scored
                output = output[0]

            # augmentation reduction: fold groups of `tta` rows into one prediction
            if tta > 1:
                output = output.unfold(0, tta, tta).mean(dim=2)
                target = target[0:target.size(0):tta]

            loss = loss_fn(output, target)
            prec1, prec5 = accuracy(output, target, topk=(1, 5))

            reduced_loss = loss.data
            if args.distributed:
                # Collective ops: the call order must be identical on every rank.
                reduced_loss = reduce_tensor(reduced_loss, args.world_size)
                prec1 = reduce_tensor(prec1, args.world_size)
                prec5 = reduce_tensor(prec5, args.world_size)

            torch.cuda.synchronize()

            losses_m.update(reduced_loss.item(), inputs.size(0))
            scored = output.size(0)
            prec1_m.update(prec1.item(), scored)
            prec5_m.update(prec5.item(), scored)

            batch_time_m.update(time.time() - end)
            end = time.time()

            is_last = batch_idx == last_idx
            if args.local_rank == 0 and (is_last or batch_idx % args.log_interval == 0):
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                      'Loss {loss.val:.4f} ({loss.avg:.4f})  '
                      'Prec@1 {top1.val:.4f} ({top1.avg:.4f})  '
                      'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format(
                          batch_idx, last_idx,
                          batch_time=batch_time_m, loss=losses_m,
                          top1=prec1_m, top5=prec5_m))

    return OrderedDict([('loss', losses_m.avg), ('prec1', prec1_m.avg), ('prec5', prec5_m.avg)])


def reduce_tensor(tensor, n):
    """Sum ``tensor`` over all processes in the default group, then divide by ``n``.

    The input tensor is not modified; a new tensor holding the result is returned.
    Passing the world size as ``n`` yields the cross-process mean.
    """
    summed = tensor.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)
    summed.div_(n)
    return summed


if __name__ == "__main__":
    # Start training only when this file is executed as a script, not when imported.
    main()