diff --git a/inference.py b/inference.py index 9077cc0756a75cd166ea7050cf573a098d580360..3255a8d90d0f74457a5483f98b3e13ceb2d63289 100644 --- a/inference.py +++ b/inference.py @@ -70,7 +70,7 @@ def main(): logging.info('Model %s created, param count: %d' % (args.model, sum([m.numel() for m in model.parameters()]))) - config = resolve_data_config(model, args) + config = resolve_data_config(vars(args), model=model) model, test_time_pool = apply_test_time_pool(model, config, args) if args.num_gpu > 1: diff --git a/timm/data/config.py b/timm/data/config.py index 1675d2a97706dce1c6fd8516644690df765b8aef..8a83d19f19124011795916b4ee18a17e83adc21d 100644 --- a/timm/data/config.py +++ b/timm/data/config.py @@ -2,35 +2,43 @@ import logging from .constants import * -def resolve_data_config(model, args, default_cfg={}, verbose=True): +def resolve_data_config(args, default_cfg={}, model=None, verbose=True): new_config = {} default_cfg = default_cfg - if not default_cfg and hasattr(model, 'default_cfg'): + if not default_cfg and model is not None and hasattr(model, 'default_cfg'): default_cfg = model.default_cfg # Resolve input/image size - # FIXME grayscale/chans arg to use different # channels? in_chans = 3 + if 'chans' in args and args['chans'] is not None: + in_chans = args['chans'] + input_size = (in_chans, 224, 224) - if args.img_size is not None: - # FIXME support passing img_size as tuple, non-square - assert isinstance(args.img_size, int) - input_size = (in_chans, args.img_size, args.img_size) + if 'input_size' in args and args['input_size'] is not None: + assert isinstance(args['input_size'], (tuple, list)) + assert len(args['input_size']) == 3 + input_size = tuple(args['input_size']) + in_chans = input_size[0] # input_size overrides in_chans + elif 'img_size' in args and args['img_size'] is not None: + assert isinstance(args['img_size'], int) + input_size = (in_chans, args['img_size'], args['img_size']) elif 'input_size' in default_cfg: input_size = default_cfg['input_size'] new_config['input_size'] = input_size # resolve interpolation method - new_config['interpolation'] = 'bilinear' - if args.interpolation: - new_config['interpolation'] = args.interpolation + new_config['interpolation'] = 'bicubic' + if 'interpolation' in args and args['interpolation']: + new_config['interpolation'] = args['interpolation'] elif 'interpolation' in default_cfg: new_config['interpolation'] = default_cfg['interpolation'] # resolve dataset + model mean for normalization - new_config['mean'] = get_mean_by_model(args.model) - if args.mean is not None: - mean = tuple(args.mean) + new_config['mean'] = IMAGENET_DEFAULT_MEAN + if 'model' in args: + new_config['mean'] = get_mean_by_model(args['model']) + if 'mean' in args and args['mean'] is not None: + mean = tuple(args['mean']) if len(mean) == 1: mean = tuple(list(mean) * in_chans) else: @@ -40,9 +48,11 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True): new_config['mean'] = default_cfg['mean'] # resolve dataset + model std deviation for normalization - new_config['std'] = get_std_by_model(args.model) - if args.std is not None: - std = tuple(args.std) + new_config['std'] = IMAGENET_DEFAULT_STD + if 'model' in args: + new_config['std'] = get_std_by_model(args['model']) + if 'std' in args and args['std'] is not None: + std = tuple(args['std']) if len(std) == 1: std = tuple(list(std) * in_chans) else: @@ -53,7 +63,9 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True): # resolve default crop percentage new_config['crop_pct'] = DEFAULT_CROP_PCT - if 'crop_pct' in default_cfg: + if 'crop_pct' in args and args['crop_pct'] is not None: + new_config['crop_pct'] = args['crop_pct'] + elif 'crop_pct' in default_cfg: new_config['crop_pct'] = default_cfg['crop_pct'] if verbose: @@ -64,29 +76,11 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True): return new_config -def get_mean_by_name(name): - if name == 'dpn': - return IMAGENET_DPN_MEAN - elif name == 'inception' or name == 'le': - return IMAGENET_INCEPTION_MEAN - else: - return IMAGENET_DEFAULT_MEAN - - -def get_std_by_name(name): - if name == 'dpn': - return IMAGENET_DPN_STD - elif name == 'inception' or name == 'le': - return IMAGENET_INCEPTION_STD - else: - return IMAGENET_DEFAULT_STD - - def get_mean_by_model(model_name): model_name = model_name.lower() if 'dpn' in model_name: return IMAGENET_DPN_STD - elif 'ception' in model_name or 'nasnet' in model_name: + elif 'ception' in model_name or ('nasnet' in model_name and 'mnasnet' not in model_name): return IMAGENET_INCEPTION_MEAN else: return IMAGENET_DEFAULT_MEAN @@ -96,7 +90,7 @@ def get_std_by_model(model_name): model_name = model_name.lower() if 'dpn' in model_name: return IMAGENET_DEFAULT_STD - elif 'ception' in model_name or 'nasnet' in model_name: + elif 'ception' in model_name or ('nasnet' in model_name and 'mnasnet' not in model_name): return IMAGENET_INCEPTION_STD else: return IMAGENET_DEFAULT_STD diff --git a/timm/data/loader.py b/timm/data/loader.py index 777eb878ffaaead2d85fa9a508d71e096c16fb53..6a19b805a281aaae3e333df3c7712ff7a3c78cfa 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -86,6 +86,7 @@ def create_loader( use_prefetcher=True, rand_erase_prob=0., rand_erase_mode='const', + color_jitter=0.4, interpolation='bilinear', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, @@ -107,6 +108,7 @@ def create_loader( if is_training: transform = transforms_imagenet_train( img_size, + color_jitter=color_jitter, interpolation=interpolation, use_prefetcher=use_prefetcher, mean=mean, diff --git a/timm/data/transforms.py b/timm/data/transforms.py index bee505a226d23a163d93f8f144fd7d89ef88394f..1e1b054add048e09d86a8f567663b048fd1c24ad 100644 --- a/timm/data/transforms.py +++ b/timm/data/transforms.py @@ -156,7 +156,7 @@ class RandomResizedCropAndInterpolation(object): def transforms_imagenet_train( img_size=224, scale=(0.08, 1.0), - color_jitter=(0.4, 0.4, 0.4), + color_jitter=0.4, interpolation='random', random_erasing=0.4, random_erasing_mode='const', @@ -164,6 +164,14 @@ def transforms_imagenet_train( mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD ): + if isinstance(color_jitter, (list, tuple)): + # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation + # or 4 if also augmenting hue + assert len(color_jitter) in (3, 4) + else: + # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue + color_jitter = (float(color_jitter),) * 3 + print(*color_jitter) tfl = [ RandomResizedCropAndInterpolation( diff --git a/timm/models/gen_efficientnet.py b/timm/models/gen_efficientnet.py index 1f5890bca388928b2e2fcf540d8403741c0734b2..0642a1cb139226a9229ba0e23437173048822634 100644 --- a/timm/models/gen_efficientnet.py +++ b/timm/models/gen_efficientnet.py @@ -1430,7 +1430,7 @@ def efficientnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ EfficientNet-B1 """ default_cfg = default_cfgs['efficientnet_b1'] # NOTE for train, drop_rate should be 0.2 - #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg + kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg model = _gen_efficientnet( channel_multiplier=1.0, depth_multiplier=1.1, num_classes=num_classes, in_chans=in_chans, **kwargs) @@ -1445,7 +1445,7 @@ def efficientnet_b2(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ EfficientNet-B2 """ default_cfg = default_cfgs['efficientnet_b2'] # NOTE for train, drop_rate should be 0.3 - #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg + kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg model = _gen_efficientnet( channel_multiplier=1.1, depth_multiplier=1.2, num_classes=num_classes, in_chans=in_chans, **kwargs) diff --git a/timm/models/helpers.py b/timm/models/helpers.py index b7de304a31dfd860a2eb910627b3f31dcc3a3f5a..1deff273b38492018e3f5023fc293ff1e365f614 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -28,8 +28,9 @@ def load_checkpoint(model, checkpoint_path, use_ema=False): raise FileNotFoundError() -def resume_checkpoint(model, checkpoint_path, start_epoch=None): +def resume_checkpoint(model, checkpoint_path): optimizer_state = None + resume_epoch = None if os.path.isfile(checkpoint_path): checkpoint = torch.load(checkpoint_path) if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: @@ -40,13 +41,15 @@ def resume_checkpoint(model, checkpoint_path, start_epoch=None): model.load_state_dict(new_state_dict) if 'optimizer' in checkpoint: optimizer_state = checkpoint['optimizer'] - start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch + if 'epoch' in checkpoint: + resume_epoch = checkpoint['epoch'] + if 'version' in checkpoint and checkpoint['version'] > 1: + resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save logging.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) else: model.load_state_dict(checkpoint) - start_epoch = 0 if start_epoch is None else start_epoch logging.info("Loaded checkpoint '{}'".format(checkpoint_path)) - return optimizer_state, start_epoch + return optimizer_state, resume_epoch else: logging.error("No checkpoint found at '{}'".format(checkpoint_path)) raise FileNotFoundError() diff --git a/timm/scheduler/scheduler.py b/timm/scheduler/scheduler.py index 59fcfc16393a62e357e7707344b661e4165a44a4..78e8460d46edbf0938d0f3721d9441a304273555 100644 --- a/timm/scheduler/scheduler.py +++ b/timm/scheduler/scheduler.py @@ -56,7 +56,7 @@ class Scheduler: def step(self, epoch: int, metric: float = None) -> None: self.metric = metric - values = self.get_epoch_values(epoch + 1) # +1 to calculate for next epoch + values = self.get_epoch_values(epoch) if values is not None: self.update_groups(values) diff --git a/timm/utils.py b/timm/utils.py index 8d4418a61744510744762c4d336c776d6e14cfff..36355c2b13ecb6f829f63d6404f864872e83ad8b 100644 --- a/timm/utils.py +++ b/timm/utils.py @@ -83,7 +83,8 @@ class CheckpointSaver: 'arch': args.model, 'state_dict': get_state_dict(model), 'optimizer': optimizer.state_dict(), - 'args': args + 'args': args, + 'version': 2, # version < 2 increments epoch before save } if model_ema is not None: save_state['state_dict_ema'] = get_state_dict(model_ema) diff --git a/train.py b/train.py index 7196305fbf5c55d01a3f5649f8fe3746431e7855..f7ecdd5db29a3711d76b44f036272c15c822db8c 100644 --- a/train.py +++ b/train.py @@ -27,22 +27,21 @@ import torchvision.utils torch.backends.cudnn.benchmark = True parser = argparse.ArgumentParser(description='Training') +# Dataset / Model parameters parser.add_argument('data', metavar='DIR', help='path to dataset') parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL', help='Name of model to train (default: "countception"') +parser.add_argument('--pretrained', action='store_true', default=False, + help='Start with pretrained version of specified network (if avail)') +parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', + help='Initialize model from this checkpoint (default: none)') +parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='Resume full model and optimizer state from checkpoint (default: none)') parser.add_argument('--num-classes', type=int, default=1000, metavar='N', help='number of label classes (default: 1000)') -parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', - help='Optimizer (default: "sgd"') -parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', - help='Optimizer Epsilon (default: 1e-8)') parser.add_argument('--gp', default='avg', type=str, metavar='POOL', help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")') -parser.add_argument('--tta', type=int, default=0, metavar='N', - help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)') -parser.add_argument('--pretrained', action='store_true', default=False, - help='Start with pretrained version of specified network (if avail)') parser.add_argument('--img-size', type=int, default=None, metavar='N', help='Image patch size (default: None => model default)') parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', @@ -53,8 +52,24 @@ parser.add_argument('--interpolation', default='', type=str, metavar='NAME', help='Image resize interpolation type (overrides model)') parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N', help='input batch size for training (default: 32)') -parser.add_argument('-s', '--initial-batch-size', type=int, default=0, metavar='N', - help='initial input batch size for training (default: 0)') +parser.add_argument('--drop', type=float, default=0.0, metavar='DROP', + help='Dropout rate (default: 0.)') +# Optimizer parameters +parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', + help='Optimizer (default: "sgd"') +parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', + help='Optimizer Epsilon (default: 1e-8)') +parser.add_argument('--momentum', type=float, default=0.9, metavar='M', + help='SGD momentum (default: 0.9)') +parser.add_argument('--weight-decay', type=float, default=0.0001, + help='weight decay (default: 0.0001)') +# Learning rate schedule parameters +parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER', + help='LR scheduler (default: "step"') +parser.add_argument('--lr', type=float, default=0.01, metavar='LR', + help='learning rate (default: 0.01)') +parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR', + help='warmup learning rate (default: 0.0001)') parser.add_argument('--epochs', type=int, default=200, metavar='N', help='number of epochs to train (default: 2)') parser.add_argument('--start-epoch', default=None, type=int, metavar='N', @@ -65,40 +80,34 @@ parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N', help='epochs to warmup LR, if scheduler supports') parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', help='LR decay rate (default: 0.1)') -parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER', - help='LR scheduler (default: "step"') -parser.add_argument('--drop', type=float, default=0.0, metavar='DROP', - help='Dropout rate (default: 0.)') +# Augmentation parameters +parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT', + help='Color jitter factor (default: 0.4)') parser.add_argument('--reprob', type=float, default=0., metavar='PCT', help='Random erase prob (default: 0.)') parser.add_argument('--remode', type=str, default='const', help='Random erase mode (default: "const")') -parser.add_argument('--lr', type=float, default=0.01, metavar='LR', - help='learning rate (default: 0.01)') -parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR', - help='warmup learning rate (default: 0.0001)') -parser.add_argument('--momentum', type=float, default=0.9, metavar='M', - help='SGD momentum (default: 0.9)') -parser.add_argument('--weight-decay', type=float, default=0.0001, - help='weight decay (default: 0.0001)') parser.add_argument('--mixup', type=float, default=0.0, help='mixup alpha, mixup enabled if > 0. (default: 0.)') parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', help='turn off mixup after this epoch, disabled if 0 (default: 0)') parser.add_argument('--smoothing', type=float, default=0.1, help='label smoothing (default: 0.1)') +# Batch norm parameters (only works with gen_efficientnet based models currently) parser.add_argument('--bn-tf', action='store_true', default=False, help='Use Tensorflow BatchNorm defaults for models that support it (default: False)') parser.add_argument('--bn-momentum', type=float, default=None, help='BatchNorm momentum override (if not None)') parser.add_argument('--bn-eps', type=float, default=None, help='BatchNorm epsilon override (if not None)') +# Model Exponential Moving Average parser.add_argument('--model-ema', action='store_true', default=False, help='Enable tracking moving average of model weights') parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.') parser.add_argument('--model-ema-decay', type=float, default=0.9998, help='decay factor for model weights moving average (default: 0.9998)') +# Misc parser.add_argument('--seed', type=int, default=42, metavar='S', help='random seed (default: 42)') parser.add_argument('--log-interval', type=int, default=50, metavar='N', @@ -109,10 +118,6 @@ parser.add_argument('-j', '--workers', type=int, default=4, metavar='N', help='how many training processes to use (default: 1)') parser.add_argument('--num-gpu', type=int, default=1, help='Number of GPUS to use') -parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', - help='path to init checkpoint (default: none)') -parser.add_argument('--resume', default='', type=str, metavar='PATH', - help='path to latest checkpoint (default: none)') parser.add_argument('--save-images', action='store_true', default=False, help='save images of input bathes every log interval for debugging') parser.add_argument('--amp', action='store_true', default=False, @@ -125,6 +130,8 @@ parser.add_argument('--output', default='', type=str, metavar='PATH', help='path to output folder (default: none, current dir)') parser.add_argument('--eval-metric', default='prec1', type=str, metavar='EVAL_METRIC', help='Best metric (default: "prec1"') +parser.add_argument('--tta', type=int, default=0, metavar='N', + help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)') parser.add_argument("--local_rank", default=0, type=int) @@ -174,13 +181,13 @@ def main(): logging.info('Model %s created, param count: %d' % (args.model, sum([m.numel() for m in model.parameters()]))) - data_config = resolve_data_config(model, args, verbose=args.local_rank == 0) + data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) # optionally resume from a checkpoint - start_epoch = 0 optimizer_state = None + resume_epoch = None if args.resume: - optimizer_state, start_epoch = resume_checkpoint(model, args.resume, args.start_epoch) + optimizer_state, resume_epoch = resume_checkpoint(model, args.resume) if args.num_gpu > 1: if args.amp: @@ -232,8 +239,15 @@ def main(): # NOTE: EMA model does not need to be wrapped by DDP lr_scheduler, num_epochs = create_scheduler(args, optimizer) + start_epoch = 0 + if args.start_epoch is not None: + # a specified start_epoch will always override the resume epoch + start_epoch = args.start_epoch + elif resume_epoch is not None: + start_epoch = resume_epoch if start_epoch > 0: lr_scheduler.step(start_epoch) + if args.local_rank == 0: logging.info('Scheduled epochs: {}'.format(num_epochs)) @@ -255,6 +269,7 @@ def main(): use_prefetcher=args.prefetcher, rand_erase_prob=args.reprob, rand_erase_mode=args.remode, + color_jitter=args.color_jitter, interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], @@ -327,7 +342,8 @@ def main(): eval_metrics = ema_eval_metrics if lr_scheduler is not None: - lr_scheduler.step(epoch, eval_metrics[eval_metric]) + # step LR for next epoch + lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) update_summary( epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), @@ -338,9 +354,7 @@ def main(): save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint( model, optimizer, args, - epoch=epoch + 1, - model_ema=model_ema, - metric=save_metric) + epoch=epoch, model_ema=model_ema, metric=save_metric) except KeyboardInterrupt: pass @@ -433,9 +447,8 @@ def train_epoch( if saver is not None and args.recovery_interval and ( last_batch or (batch_idx + 1) % args.recovery_interval == 0): - save_epoch = epoch + 1 if last_batch else epoch saver.save_recovery( - model, optimizer, args, save_epoch, model_ema=model_ema, batch_idx=batch_idx) + model, optimizer, args, epoch, model_ema=model_ema, batch_idx=batch_idx) if lr_scheduler is not None: lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) diff --git a/validate.py b/validate.py index 199888ad2c27a7c6246ffc6362f50122858e06c8..280bc2604d2aecd78b68534f8769b65b7ec59390 100644 --- a/validate.py +++ b/validate.py @@ -71,7 +71,7 @@ def validate(args): param_count = sum([m.numel() for m in model.parameters()]) logging.info('Model %s created, param count: %d' % (args.model, param_count)) - data_config = resolve_data_config(model, args) + data_config = resolve_data_config(vars(args), model=model) model, test_time_pool = apply_test_time_pool(model, data_config, args) if args.num_gpu > 1: