Commit c6b32cbe authored by Ross Wightman

A number of tweaks to arguments, epoch handling, config

* reorganize train args
* allow resolve_data_config to be used with dict args, not just argparse (see the sketch below)
* stop incrementing epoch before save, more consistent naming vs csv, etc
* update resume and start epoch handling to match above
* stop auto-incrementing epoch in scheduler
Parent 9d653b68
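As a minimal sketch of the dict-args change from the bullet list above: `resolve_data_config` now takes the args dict first and the model by keyword, so callers can pass `vars(args)` from argparse or a hand-built dict. Import paths and the `my_args` values below are illustrative assumptions, not repo code:

```python
# Sketch, assuming these import paths; resolve_data_config now accepts a
# plain dict of overrides instead of an argparse Namespace.
from timm.models import create_model
from timm.data import resolve_data_config

model = create_model('efficientnet_b1')

# from argparse:  config = resolve_data_config(vars(args), model=model)
# or from a hand-built dict (values here are illustrative):
my_args = {'model': 'efficientnet_b1', 'img_size': 240}
config = resolve_data_config(my_args, model=model)
print(config['input_size'], config['interpolation'])
```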
@@ -70,7 +70,7 @@ def main():
     logging.info('Model %s created, param count: %d' %
                  (args.model, sum([m.numel() for m in model.parameters()])))
-    config = resolve_data_config(model, args)
+    config = resolve_data_config(vars(args), model=model)
     model, test_time_pool = apply_test_time_pool(model, config, args)
     if args.num_gpu > 1:
......
@@ -2,35 +2,43 @@ import logging
 from .constants import *

-def resolve_data_config(model, args, default_cfg={}, verbose=True):
+def resolve_data_config(args, default_cfg={}, model=None, verbose=True):
     new_config = {}
     default_cfg = default_cfg
-    if not default_cfg and hasattr(model, 'default_cfg'):
+    if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
         default_cfg = model.default_cfg

     # Resolve input/image size
+    # FIXME grayscale/chans arg to use different # channels?
     in_chans = 3
+    if 'chans' in args and args['chans'] is not None:
+        in_chans = args['chans']
+
     input_size = (in_chans, 224, 224)
-    if args.img_size is not None:
-        # FIXME support passing img_size as tuple, non-square
-        assert isinstance(args.img_size, int)
-        input_size = (in_chans, args.img_size, args.img_size)
+    if 'input_size' in args and args['input_size'] is not None:
+        assert isinstance(args['input_size'], (tuple, list))
+        assert len(args['input_size']) == 3
+        input_size = tuple(args['input_size'])
+        in_chans = input_size[0]  # input_size overrides in_chans
+    elif 'img_size' in args and args['img_size'] is not None:
+        assert isinstance(args['img_size'], int)
+        input_size = (in_chans, args['img_size'], args['img_size'])
     elif 'input_size' in default_cfg:
         input_size = default_cfg['input_size']
     new_config['input_size'] = input_size

     # resolve interpolation method
-    new_config['interpolation'] = 'bilinear'
-    if args.interpolation:
-        new_config['interpolation'] = args.interpolation
+    new_config['interpolation'] = 'bicubic'
+    if 'interpolation' in args and args['interpolation']:
+        new_config['interpolation'] = args['interpolation']
     elif 'interpolation' in default_cfg:
         new_config['interpolation'] = default_cfg['interpolation']

     # resolve dataset + model mean for normalization
-    new_config['mean'] = get_mean_by_model(args.model)
-    if args.mean is not None:
-        mean = tuple(args.mean)
+    new_config['mean'] = IMAGENET_DEFAULT_MEAN
+    if 'model' in args:
+        new_config['mean'] = get_mean_by_model(args['model'])
+    if 'mean' in args and args['mean'] is not None:
+        mean = tuple(args['mean'])
         if len(mean) == 1:
             mean = tuple(list(mean) * in_chans)
         else:
@@ -40,9 +48,11 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True):
         new_config['mean'] = default_cfg['mean']

     # resolve dataset + model std deviation for normalization
-    new_config['std'] = get_std_by_model(args.model)
-    if args.std is not None:
-        std = tuple(args.std)
+    new_config['std'] = IMAGENET_DEFAULT_STD
+    if 'model' in args:
+        new_config['std'] = get_std_by_model(args['model'])
+    if 'std' in args and args['std'] is not None:
+        std = tuple(args['std'])
         if len(std) == 1:
             std = tuple(list(std) * in_chans)
         else:
@@ -53,7 +63,9 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True):
     # resolve default crop percentage
     new_config['crop_pct'] = DEFAULT_CROP_PCT
-    if 'crop_pct' in default_cfg:
+    if 'crop_pct' in args and args['crop_pct'] is not None:
+        new_config['crop_pct'] = args['crop_pct']
+    elif 'crop_pct' in default_cfg:
         new_config['crop_pct'] = default_cfg['crop_pct']

     if verbose:
@@ -64,29 +76,11 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True):
     return new_config

-def get_mean_by_name(name):
-    if name == 'dpn':
-        return IMAGENET_DPN_MEAN
-    elif name == 'inception' or name == 'le':
-        return IMAGENET_INCEPTION_MEAN
-    else:
-        return IMAGENET_DEFAULT_MEAN
-
-def get_std_by_name(name):
-    if name == 'dpn':
-        return IMAGENET_DPN_STD
-    elif name == 'inception' or name == 'le':
-        return IMAGENET_INCEPTION_STD
-    else:
-        return IMAGENET_DEFAULT_STD
-
 def get_mean_by_model(model_name):
     model_name = model_name.lower()
     if 'dpn' in model_name:
         return IMAGENET_DPN_STD
-    elif 'ception' in model_name or 'nasnet' in model_name:
+    elif 'ception' in model_name or ('nasnet' in model_name and 'mnasnet' not in model_name):
         return IMAGENET_INCEPTION_MEAN
     else:
         return IMAGENET_DEFAULT_MEAN
@@ -96,7 +90,7 @@ def get_std_by_model(model_name):
     model_name = model_name.lower()
     if 'dpn' in model_name:
         return IMAGENET_DEFAULT_STD
-    elif 'ception' in model_name or 'nasnet' in model_name:
+    elif 'ception' in model_name or ('nasnet' in model_name and 'mnasnet' not in model_name):
         return IMAGENET_INCEPTION_STD
     else:
         return IMAGENET_DEFAULT_STD
@@ -86,6 +86,7 @@ def create_loader(
         use_prefetcher=True,
         rand_erase_prob=0.,
         rand_erase_mode='const',
+        color_jitter=0.4,
         interpolation='bilinear',
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD,
@@ -107,6 +108,7 @@ def create_loader(
     if is_training:
         transform = transforms_imagenet_train(
             img_size,
+            color_jitter=color_jitter,
             interpolation=interpolation,
             use_prefetcher=use_prefetcher,
             mean=mean,
......
@@ -156,7 +156,7 @@ class RandomResizedCropAndInterpolation(object):
 def transforms_imagenet_train(
         img_size=224,
         scale=(0.08, 1.0),
-        color_jitter=(0.4, 0.4, 0.4),
+        color_jitter=0.4,
         interpolation='random',
         random_erasing=0.4,
         random_erasing_mode='const',
@@ -164,6 +164,14 @@ def transforms_imagenet_train(
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD
 ):
+    if isinstance(color_jitter, (list, tuple)):
+        # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
+        # or 4 if also augmenting hue
+        assert len(color_jitter) in (3, 4)
+    else:
+        # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
+        color_jitter = (float(color_jitter),) * 3
+    print(*color_jitter)
     tfl = [
         RandomResizedCropAndInterpolation(
......
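A minimal standalone sketch of the scalar-or-tuple `color_jitter` handling added above, showing how the normalized tuple can feed torchvision's `ColorJitter`; `build_color_jitter` is a hypothetical helper, not part of the commit:

```python
# Sketch (hypothetical helper): mirrors the normalization in the hunk above,
# then unpacks the tuple into torchvision's ColorJitter.
from torchvision import transforms

def build_color_jitter(color_jitter=0.4):
    if isinstance(color_jitter, (list, tuple)):
        # 3 values -> brightness/contrast/saturation, 4 -> those plus hue
        assert len(color_jitter) in (3, 4)
    else:
        # a scalar duplicates for brightness, contrast, and saturation; no hue
        color_jitter = (float(color_jitter),) * 3
    return transforms.ColorJitter(*color_jitter)

jitter = build_color_jitter(0.4)                    # ColorJitter(0.4, 0.4, 0.4)
jitter = build_color_jitter((0.4, 0.4, 0.4, 0.1))   # also jitters hue
```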
@@ -1430,7 +1430,7 @@ def efficientnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """ EfficientNet-B1 """
     default_cfg = default_cfgs['efficientnet_b1']
     # NOTE for train, drop_rate should be 0.2
-    #kwargs['drop_connect_rate'] = 0.2  # set when training, TODO add as cmd arg
+    kwargs['drop_connect_rate'] = 0.2  # set when training, TODO add as cmd arg
     model = _gen_efficientnet(
         channel_multiplier=1.0, depth_multiplier=1.1,
         num_classes=num_classes, in_chans=in_chans, **kwargs)
@@ -1445,7 +1445,7 @@ def efficientnet_b2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """ EfficientNet-B2 """
     default_cfg = default_cfgs['efficientnet_b2']
     # NOTE for train, drop_rate should be 0.3
-    #kwargs['drop_connect_rate'] = 0.2  # set when training, TODO add as cmd arg
+    kwargs['drop_connect_rate'] = 0.2  # set when training, TODO add as cmd arg
     model = _gen_efficientnet(
         channel_multiplier=1.1, depth_multiplier=1.2,
         num_classes=num_classes, in_chans=in_chans, **kwargs)
......
@@ -28,8 +28,9 @@ def load_checkpoint(model, checkpoint_path, use_ema=False):
         raise FileNotFoundError()

-def resume_checkpoint(model, checkpoint_path, start_epoch=None):
+def resume_checkpoint(model, checkpoint_path):
     optimizer_state = None
+    resume_epoch = None
     if os.path.isfile(checkpoint_path):
         checkpoint = torch.load(checkpoint_path)
         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
@@ -40,13 +41,15 @@ def resume_checkpoint(model, checkpoint_path, start_epoch=None):
             model.load_state_dict(new_state_dict)
             if 'optimizer' in checkpoint:
                 optimizer_state = checkpoint['optimizer']
-            start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
+            if 'epoch' in checkpoint:
+                resume_epoch = checkpoint['epoch']
+                if 'version' in checkpoint and checkpoint['version'] > 1:
+                    resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save
             logging.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
         else:
             model.load_state_dict(checkpoint)
-            start_epoch = 0 if start_epoch is None else start_epoch
             logging.info("Loaded checkpoint '{}'".format(checkpoint_path))
-        return optimizer_state, start_epoch
+        return optimizer_state, resume_epoch
     else:
         logging.error("No checkpoint found at '{}'".format(checkpoint_path))
         raise FileNotFoundError()
......
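The intended caller-side flow, assembled from the train.py hunk further below (a sketch fragment assuming `model` and `args` as in train.py, not additional repo code):

```python
# Sketch: resume_checkpoint no longer takes start_epoch; the caller resolves it.
optimizer_state, resume_epoch = resume_checkpoint(model, args.resume)

start_epoch = 0
if args.start_epoch is not None:
    # an explicit --start-epoch always overrides the checkpoint's epoch
    start_epoch = args.start_epoch
elif resume_epoch is not None:
    # for 'version' >= 2 checkpoints this is already checkpoint epoch + 1
    start_epoch = resume_epoch
```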
@@ -56,7 +56,7 @@ class Scheduler:
     def step(self, epoch: int, metric: float = None) -> None:
         self.metric = metric
-        values = self.get_epoch_values(epoch + 1)  # +1 to calculate for next epoch
+        values = self.get_epoch_values(epoch)
         if values is not None:
             self.update_groups(values)
......
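Since `Scheduler.step()` no longer adds 1 internally, callers must state which epoch they want LR values for. A sketch of the resulting convention, mirroring the train.py hunks below; the import path and `train_and_validate_one_epoch` helper are assumptions for illustration:

```python
# Sketch of the new epoch-stepping convention (names follow train.py below).
from timm.scheduler import create_scheduler  # import path is an assumption

lr_scheduler, num_epochs = create_scheduler(args, optimizer)

# resuming: align the schedule with the epoch training restarts at
if start_epoch > 0:
    lr_scheduler.step(start_epoch)

for epoch in range(start_epoch, num_epochs):
    eval_metrics = train_and_validate_one_epoch()  # hypothetical helper
    # end of epoch: explicitly schedule the LR for the *next* epoch
    lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
```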
@@ -83,7 +83,8 @@ class CheckpointSaver:
             'arch': args.model,
             'state_dict': get_state_dict(model),
             'optimizer': optimizer.state_dict(),
-            'args': args
+            'args': args,
+            'version': 2,  # version < 2 increments epoch before save
         }
         if model_ema is not None:
             save_state['state_dict_ema'] = get_state_dict(model_ema)
......
@@ -27,22 +27,21 @@ import torchvision.utils
 torch.backends.cudnn.benchmark = True

 parser = argparse.ArgumentParser(description='Training')
+# Dataset / Model parameters
 parser.add_argument('data', metavar='DIR',
                     help='path to dataset')
 parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
                     help='Name of model to train (default: "countception"')
+parser.add_argument('--pretrained', action='store_true', default=False,
+                    help='Start with pretrained version of specified network (if avail)')
+parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
+                    help='Initialize model from this checkpoint (default: none)')
+parser.add_argument('--resume', default='', type=str, metavar='PATH',
+                    help='Resume full model and optimizer state from checkpoint (default: none)')
 parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
                     help='number of label classes (default: 1000)')
-parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
-                    help='Optimizer (default: "sgd"')
-parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
-                    help='Optimizer Epsilon (default: 1e-8)')
 parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
                     help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")')
-parser.add_argument('--tta', type=int, default=0, metavar='N',
-                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
-parser.add_argument('--pretrained', action='store_true', default=False,
-                    help='Start with pretrained version of specified network (if avail)')
 parser.add_argument('--img-size', type=int, default=None, metavar='N',
                     help='Image patch size (default: None => model default)')
 parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
@@ -53,8 +52,24 @@ parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                     help='Image resize interpolation type (overrides model)')
 parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
                     help='input batch size for training (default: 32)')
-parser.add_argument('-s', '--initial-batch-size', type=int, default=0, metavar='N',
-                    help='initial input batch size for training (default: 0)')
+parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
+                    help='Dropout rate (default: 0.)')
+# Optimizer parameters
+parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
+                    help='Optimizer (default: "sgd"')
+parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
+                    help='Optimizer Epsilon (default: 1e-8)')
+parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
+                    help='SGD momentum (default: 0.9)')
+parser.add_argument('--weight-decay', type=float, default=0.0001,
+                    help='weight decay (default: 0.0001)')
+# Learning rate schedule parameters
+parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
+                    help='LR scheduler (default: "step"')
+parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
+                    help='learning rate (default: 0.01)')
+parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
+                    help='warmup learning rate (default: 0.0001)')
 parser.add_argument('--epochs', type=int, default=200, metavar='N',
                     help='number of epochs to train (default: 2)')
 parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
@@ -65,40 +80,34 @@ parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
                     help='epochs to warmup LR, if scheduler supports')
 parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                     help='LR decay rate (default: 0.1)')
-parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
-                    help='LR scheduler (default: "step"')
-parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
-                    help='Dropout rate (default: 0.)')
+# Augmentation parameters
+parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
+                    help='Color jitter factor (default: 0.4)')
 parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
                     help='Random erase prob (default: 0.)')
 parser.add_argument('--remode', type=str, default='const',
                     help='Random erase mode (default: "const")')
-parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
-                    help='learning rate (default: 0.01)')
-parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
-                    help='warmup learning rate (default: 0.0001)')
-parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
-                    help='SGD momentum (default: 0.9)')
-parser.add_argument('--weight-decay', type=float, default=0.0001,
-                    help='weight decay (default: 0.0001)')
 parser.add_argument('--mixup', type=float, default=0.0,
                     help='mixup alpha, mixup enabled if > 0. (default: 0.)')
 parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                     help='turn off mixup after this epoch, disabled if 0 (default: 0)')
 parser.add_argument('--smoothing', type=float, default=0.1,
                     help='label smoothing (default: 0.1)')
+# Batch norm parameters (only works with gen_efficientnet based models currently)
 parser.add_argument('--bn-tf', action='store_true', default=False,
                     help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
 parser.add_argument('--bn-momentum', type=float, default=None,
                     help='BatchNorm momentum override (if not None)')
 parser.add_argument('--bn-eps', type=float, default=None,
                     help='BatchNorm epsilon override (if not None)')
+# Model Exponential Moving Average
 parser.add_argument('--model-ema', action='store_true', default=False,
                     help='Enable tracking moving average of model weights')
 parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
                     help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
 parser.add_argument('--model-ema-decay', type=float, default=0.9998,
                     help='decay factor for model weights moving average (default: 0.9998)')
+# Misc
 parser.add_argument('--seed', type=int, default=42, metavar='S',
                     help='random seed (default: 42)')
 parser.add_argument('--log-interval', type=int, default=50, metavar='N',
@@ -109,10 +118,6 @@ parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                     help='how many training processes to use (default: 1)')
 parser.add_argument('--num-gpu', type=int, default=1,
                     help='Number of GPUS to use')
-parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
-                    help='path to init checkpoint (default: none)')
-parser.add_argument('--resume', default='', type=str, metavar='PATH',
-                    help='path to latest checkpoint (default: none)')
 parser.add_argument('--save-images', action='store_true', default=False,
                     help='save images of input bathes every log interval for debugging')
 parser.add_argument('--amp', action='store_true', default=False,
@@ -125,6 +130,8 @@ parser.add_argument('--output', default='', type=str, metavar='PATH',
                     help='path to output folder (default: none, current dir)')
 parser.add_argument('--eval-metric', default='prec1', type=str, metavar='EVAL_METRIC',
                     help='Best metric (default: "prec1"')
+parser.add_argument('--tta', type=int, default=0, metavar='N',
+                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
 parser.add_argument("--local_rank", default=0, type=int)
@@ -174,13 +181,13 @@ def main():
     logging.info('Model %s created, param count: %d' %
                  (args.model, sum([m.numel() for m in model.parameters()])))

-    data_config = resolve_data_config(model, args, verbose=args.local_rank == 0)
+    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

     # optionally resume from a checkpoint
-    start_epoch = 0
     optimizer_state = None
+    resume_epoch = None
     if args.resume:
-        optimizer_state, start_epoch = resume_checkpoint(model, args.resume, args.start_epoch)
+        optimizer_state, resume_epoch = resume_checkpoint(model, args.resume)

     if args.num_gpu > 1:
         if args.amp:
@@ -232,8 +239,15 @@ def main():
     # NOTE: EMA model does not need to be wrapped by DDP

     lr_scheduler, num_epochs = create_scheduler(args, optimizer)
+    start_epoch = 0
+    if args.start_epoch is not None:
+        # a specified start_epoch will always override the resume epoch
+        start_epoch = args.start_epoch
+    elif resume_epoch is not None:
+        start_epoch = resume_epoch
     if start_epoch > 0:
         lr_scheduler.step(start_epoch)
+
     if args.local_rank == 0:
         logging.info('Scheduled epochs: {}'.format(num_epochs))
@@ -255,6 +269,7 @@ def main():
         use_prefetcher=args.prefetcher,
         rand_erase_prob=args.reprob,
         rand_erase_mode=args.remode,
+        color_jitter=args.color_jitter,
         interpolation='random',  # FIXME cleanly resolve this? data_config['interpolation'],
         mean=data_config['mean'],
         std=data_config['std'],
@@ -327,7 +342,8 @@ def main():
                 eval_metrics = ema_eval_metrics

             if lr_scheduler is not None:
-                lr_scheduler.step(epoch, eval_metrics[eval_metric])
+                # step LR for next epoch
+                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

             update_summary(
                 epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
@@ -338,9 +354,7 @@ def main():
                 save_metric = eval_metrics[eval_metric]
                 best_metric, best_epoch = saver.save_checkpoint(
                     model, optimizer, args,
-                    epoch=epoch + 1,
-                    model_ema=model_ema,
-                    metric=save_metric)
+                    epoch=epoch, model_ema=model_ema, metric=save_metric)

     except KeyboardInterrupt:
         pass
@@ -433,9 +447,8 @@ def train_epoch(
         if saver is not None and args.recovery_interval and (
                 last_batch or (batch_idx + 1) % args.recovery_interval == 0):
-            save_epoch = epoch + 1 if last_batch else epoch
             saver.save_recovery(
-                model, optimizer, args, save_epoch, model_ema=model_ema, batch_idx=batch_idx)
+                model, optimizer, args, epoch, model_ema=model_ema, batch_idx=batch_idx)

         if lr_scheduler is not None:
             lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
......
@@ -71,7 +71,7 @@ def validate(args):
     param_count = sum([m.numel() for m in model.parameters()])
     logging.info('Model %s created, param count: %d' % (args.model, param_count))

-    data_config = resolve_data_config(model, args)
+    data_config = resolve_data_config(vars(args), model=model)
     model, test_time_pool = apply_test_time_pool(model, data_config, args)

     if args.num_gpu > 1:
......