提交 c6b32cbe 编写于 作者: R Ross Wightman

A number of tweaks to arguments, epoch handling, config

* reorganize train args
* allow resolve_data_config to be used with dict args, not just arparse
* stop incrementing epoch before save, more consistent naming vs csv, etc
* update resume and start epoch handling to match above
* stop auto-incrementing epoch in scheduler
上级 9d653b68
......@@ -70,7 +70,7 @@ def main():
logging.info('Model %s created, param count: %d' %
(args.model, sum([m.numel() for m in model.parameters()])))
config = resolve_data_config(model, args)
config = resolve_data_config(vars(args), model=model)
model, test_time_pool = apply_test_time_pool(model, config, args)
if args.num_gpu > 1:
......
......@@ -2,35 +2,43 @@ import logging
from .constants import *
def resolve_data_config(model, args, default_cfg={}, verbose=True):
def resolve_data_config(args, default_cfg={}, model=None, verbose=True):
new_config = {}
default_cfg = default_cfg
if not default_cfg and hasattr(model, 'default_cfg'):
if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
default_cfg = model.default_cfg
# Resolve input/image size
# FIXME grayscale/chans arg to use different # channels?
in_chans = 3
if 'chans' in args and args['chans'] is not None:
in_chans = args['chans']
input_size = (in_chans, 224, 224)
if args.img_size is not None:
# FIXME support passing img_size as tuple, non-square
assert isinstance(args.img_size, int)
input_size = (in_chans, args.img_size, args.img_size)
if 'input_size' in args and args['input_size'] is not None:
assert isinstance(args['input_size'], (tuple, list))
assert len(args['input_size']) == 3
input_size = tuple(args['input_size'])
in_chans = input_size[0] # input_size overrides in_chans
elif 'img_size' in args and args['img_size'] is not None:
assert isinstance(args['img_size'], int)
input_size = (in_chans, args['img_size'], args['img_size'])
elif 'input_size' in default_cfg:
input_size = default_cfg['input_size']
new_config['input_size'] = input_size
# resolve interpolation method
new_config['interpolation'] = 'bilinear'
if args.interpolation:
new_config['interpolation'] = args.interpolation
new_config['interpolation'] = 'bicubic'
if 'interpolation' in args and args['interpolation']:
new_config['interpolation'] = args['interpolation']
elif 'interpolation' in default_cfg:
new_config['interpolation'] = default_cfg['interpolation']
# resolve dataset + model mean for normalization
new_config['mean'] = get_mean_by_model(args.model)
if args.mean is not None:
mean = tuple(args.mean)
new_config['mean'] = IMAGENET_DEFAULT_MEAN
if 'model' in args:
new_config['mean'] = get_mean_by_model(args['model'])
if 'mean' in args and args['mean'] is not None:
mean = tuple(args['mean'])
if len(mean) == 1:
mean = tuple(list(mean) * in_chans)
else:
......@@ -40,9 +48,11 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True):
new_config['mean'] = default_cfg['mean']
# resolve dataset + model std deviation for normalization
new_config['std'] = get_std_by_model(args.model)
if args.std is not None:
std = tuple(args.std)
new_config['std'] = IMAGENET_DEFAULT_STD
if 'model' in args:
new_config['std'] = get_std_by_model(args['model'])
if 'std' in args and args['std'] is not None:
std = tuple(args['std'])
if len(std) == 1:
std = tuple(list(std) * in_chans)
else:
......@@ -53,7 +63,9 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True):
# resolve default crop percentage
new_config['crop_pct'] = DEFAULT_CROP_PCT
if 'crop_pct' in default_cfg:
if 'crop_pct' in args and args['crop_pct'] is not None:
new_config['crop_pct'] = args['crop_pct']
elif 'crop_pct' in default_cfg:
new_config['crop_pct'] = default_cfg['crop_pct']
if verbose:
......@@ -64,29 +76,11 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True):
return new_config
def get_mean_by_name(name):
if name == 'dpn':
return IMAGENET_DPN_MEAN
elif name == 'inception' or name == 'le':
return IMAGENET_INCEPTION_MEAN
else:
return IMAGENET_DEFAULT_MEAN
def get_std_by_name(name):
if name == 'dpn':
return IMAGENET_DPN_STD
elif name == 'inception' or name == 'le':
return IMAGENET_INCEPTION_STD
else:
return IMAGENET_DEFAULT_STD
def get_mean_by_model(model_name):
model_name = model_name.lower()
if 'dpn' in model_name:
return IMAGENET_DPN_STD
elif 'ception' in model_name or 'nasnet' in model_name:
elif 'ception' in model_name or ('nasnet' in model_name and 'mnasnet' not in model_name):
return IMAGENET_INCEPTION_MEAN
else:
return IMAGENET_DEFAULT_MEAN
......@@ -96,7 +90,7 @@ def get_std_by_model(model_name):
model_name = model_name.lower()
if 'dpn' in model_name:
return IMAGENET_DEFAULT_STD
elif 'ception' in model_name or 'nasnet' in model_name:
elif 'ception' in model_name or ('nasnet' in model_name and 'mnasnet' not in model_name):
return IMAGENET_INCEPTION_STD
else:
return IMAGENET_DEFAULT_STD
......@@ -86,6 +86,7 @@ def create_loader(
use_prefetcher=True,
rand_erase_prob=0.,
rand_erase_mode='const',
color_jitter=0.4,
interpolation='bilinear',
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
......@@ -107,6 +108,7 @@ def create_loader(
if is_training:
transform = transforms_imagenet_train(
img_size,
color_jitter=color_jitter,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
mean=mean,
......
......@@ -156,7 +156,7 @@ class RandomResizedCropAndInterpolation(object):
def transforms_imagenet_train(
img_size=224,
scale=(0.08, 1.0),
color_jitter=(0.4, 0.4, 0.4),
color_jitter=0.4,
interpolation='random',
random_erasing=0.4,
random_erasing_mode='const',
......@@ -164,6 +164,14 @@ def transforms_imagenet_train(
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD
):
if isinstance(color_jitter, (list, tuple)):
# color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
# or 4 if also augmenting hue
assert len(color_jitter) in (3, 4)
else:
# if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
color_jitter = (float(color_jitter),) * 3
print(*color_jitter)
tfl = [
RandomResizedCropAndInterpolation(
......
......@@ -1430,7 +1430,7 @@ def efficientnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" EfficientNet-B1 """
default_cfg = default_cfgs['efficientnet_b1']
# NOTE for train, drop_rate should be 0.2
#kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
model = _gen_efficientnet(
channel_multiplier=1.0, depth_multiplier=1.1,
num_classes=num_classes, in_chans=in_chans, **kwargs)
......@@ -1445,7 +1445,7 @@ def efficientnet_b2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" EfficientNet-B2 """
default_cfg = default_cfgs['efficientnet_b2']
# NOTE for train, drop_rate should be 0.3
#kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
model = _gen_efficientnet(
channel_multiplier=1.1, depth_multiplier=1.2,
num_classes=num_classes, in_chans=in_chans, **kwargs)
......
......@@ -28,8 +28,9 @@ def load_checkpoint(model, checkpoint_path, use_ema=False):
raise FileNotFoundError()
def resume_checkpoint(model, checkpoint_path, start_epoch=None):
def resume_checkpoint(model, checkpoint_path):
optimizer_state = None
resume_epoch = None
if os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
......@@ -40,13 +41,15 @@ def resume_checkpoint(model, checkpoint_path, start_epoch=None):
model.load_state_dict(new_state_dict)
if 'optimizer' in checkpoint:
optimizer_state = checkpoint['optimizer']
start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
if 'epoch' in checkpoint:
resume_epoch = checkpoint['epoch']
if 'version' in checkpoint and checkpoint['version'] > 1:
resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
logging.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
else:
model.load_state_dict(checkpoint)
start_epoch = 0 if start_epoch is None else start_epoch
logging.info("Loaded checkpoint '{}'".format(checkpoint_path))
return optimizer_state, start_epoch
return optimizer_state, resume_epoch
else:
logging.error("No checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError()
......
......@@ -56,7 +56,7 @@ class Scheduler:
def step(self, epoch: int, metric: float = None) -> None:
self.metric = metric
values = self.get_epoch_values(epoch + 1) # +1 to calculate for next epoch
values = self.get_epoch_values(epoch)
if values is not None:
self.update_groups(values)
......
......@@ -83,7 +83,8 @@ class CheckpointSaver:
'arch': args.model,
'state_dict': get_state_dict(model),
'optimizer': optimizer.state_dict(),
'args': args
'args': args,
'version': 2, # version < 2 increments epoch before save
}
if model_ema is not None:
save_state['state_dict_ema'] = get_state_dict(model_ema)
......
......@@ -27,22 +27,21 @@ import torchvision.utils
torch.backends.cudnn.benchmark = True
parser = argparse.ArgumentParser(description='Training')
# Dataset / Model parameters
parser.add_argument('data', metavar='DIR',
help='path to dataset')
parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
help='Name of model to train (default: "countception"')
parser.add_argument('--pretrained', action='store_true', default=False,
help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
help='Initialize model from this checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
help='number of label classes (default: 1000)')
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "sgd"')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")')
parser.add_argument('--tta', type=int, default=0, metavar='N',
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument('--pretrained', action='store_true', default=False,
help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--img-size', type=int, default=None, metavar='N',
help='Image patch size (default: None => model default)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
......@@ -53,8 +52,24 @@ parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
help='input batch size for training (default: 32)')
parser.add_argument('-s', '--initial-batch-size', type=int, default=0, metavar='N',
help='initial input batch size for training (default: 0)')
parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
help='Dropout rate (default: 0.)')
# Optimizer parameters
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "sgd"')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.0001,
help='weight decay (default: 0.0001)')
# Learning rate schedule parameters
parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "step"')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
help='warmup learning rate (default: 0.0001)')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
help='number of epochs to train (default: 2)')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
......@@ -65,40 +80,34 @@ parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "step"')
parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
help='Dropout rate (default: 0.)')
# Augmentation parameters
parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)')
parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
help='Random erase prob (default: 0.)')
parser.add_argument('--remode', type=str, default='const',
help='Random erase mode (default: "const")')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
help='warmup learning rate (default: 0.0001)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.0001,
help='weight decay (default: 0.0001)')
parser.add_argument('--mixup', type=float, default=0.0,
help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
help='turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
help='label smoothing (default: 0.1)')
# Batch norm parameters (only works with gen_efficientnet based models currently)
parser.add_argument('--bn-tf', action='store_true', default=False,
help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
parser.add_argument('--bn-momentum', type=float, default=None,
help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None,
help='BatchNorm epsilon override (if not None)')
# Model Exponential Moving Average
parser.add_argument('--model-ema', action='store_true', default=False,
help='Enable tracking moving average of model weights')
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
parser.add_argument('--model-ema-decay', type=float, default=0.9998,
help='decay factor for model weights moving average (default: 0.9998)')
# Misc
parser.add_argument('--seed', type=int, default=42, metavar='S',
help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
......@@ -109,10 +118,6 @@ parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
help='how many training processes to use (default: 1)')
parser.add_argument('--num-gpu', type=int, default=1,
help='Number of GPUS to use')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
help='path to init checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('--save-images', action='store_true', default=False,
help='save images of input bathes every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
......@@ -125,6 +130,8 @@ parser.add_argument('--output', default='', type=str, metavar='PATH',
help='path to output folder (default: none, current dir)')
parser.add_argument('--eval-metric', default='prec1', type=str, metavar='EVAL_METRIC',
help='Best metric (default: "prec1"')
parser.add_argument('--tta', type=int, default=0, metavar='N',
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument("--local_rank", default=0, type=int)
......@@ -174,13 +181,13 @@ def main():
logging.info('Model %s created, param count: %d' %
(args.model, sum([m.numel() for m in model.parameters()])))
data_config = resolve_data_config(model, args, verbose=args.local_rank == 0)
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
# optionally resume from a checkpoint
start_epoch = 0
optimizer_state = None
resume_epoch = None
if args.resume:
optimizer_state, start_epoch = resume_checkpoint(model, args.resume, args.start_epoch)
optimizer_state, resume_epoch = resume_checkpoint(model, args.resume)
if args.num_gpu > 1:
if args.amp:
......@@ -232,8 +239,15 @@ def main():
# NOTE: EMA model does not need to be wrapped by DDP
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
start_epoch = 0
if args.start_epoch is not None:
# a specified start_epoch will always override the resume epoch
start_epoch = args.start_epoch
elif resume_epoch is not None:
start_epoch = resume_epoch
if start_epoch > 0:
lr_scheduler.step(start_epoch)
if args.local_rank == 0:
logging.info('Scheduled epochs: {}'.format(num_epochs))
......@@ -255,6 +269,7 @@ def main():
use_prefetcher=args.prefetcher,
rand_erase_prob=args.reprob,
rand_erase_mode=args.remode,
color_jitter=args.color_jitter,
interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
......@@ -327,7 +342,8 @@ def main():
eval_metrics = ema_eval_metrics
if lr_scheduler is not None:
lr_scheduler.step(epoch, eval_metrics[eval_metric])
# step LR for next epoch
lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
update_summary(
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
......@@ -338,9 +354,7 @@ def main():
save_metric = eval_metrics[eval_metric]
best_metric, best_epoch = saver.save_checkpoint(
model, optimizer, args,
epoch=epoch + 1,
model_ema=model_ema,
metric=save_metric)
epoch=epoch, model_ema=model_ema, metric=save_metric)
except KeyboardInterrupt:
pass
......@@ -433,9 +447,8 @@ def train_epoch(
if saver is not None and args.recovery_interval and (
last_batch or (batch_idx + 1) % args.recovery_interval == 0):
save_epoch = epoch + 1 if last_batch else epoch
saver.save_recovery(
model, optimizer, args, save_epoch, model_ema=model_ema, batch_idx=batch_idx)
model, optimizer, args, epoch, model_ema=model_ema, batch_idx=batch_idx)
if lr_scheduler is not None:
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
......
......@@ -71,7 +71,7 @@ def validate(args):
param_count = sum([m.numel() for m in model.parameters()])
logging.info('Model %s created, param count: %d' % (args.model, param_count))
data_config = resolve_data_config(model, args)
data_config = resolve_data_config(vars(args), model=model)
model, test_time_pool = apply_test_time_pool(model, data_config, args)
if args.num_gpu > 1:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册