diff --git a/tools/train.py b/tools/train.py index 6dc362e845997f9f471d8320461feb3f83e8dfbb..4384b1528f4418035ca03218aeb7fa62cd94ac4b 100644 --- a/tools/train.py +++ b/tools/train.py @@ -112,7 +112,7 @@ def main(args): if config.validate and epoch_id % config.valid_interval == 0: if config.get('use_ema'): logger.info(logger.coloring("EMA validate start...")) - with train_fetchs('ema').apply(exe): + with ema.apply(exe): top1_acc = program.run(valid_dataloader, exe, compiled_valid_prog, valid_fetchs, epoch_id, 'valid')