未验证 提交 856628d4 编写于 作者: R ruri 提交者: GitHub

Merge pull request #188 from shippingwang/refine_save

add print_interval and refine override
...@@ -144,9 +144,14 @@ def override(dl, ks, v): ...@@ -144,9 +144,14 @@ def override(dl, ks, v):
override(dl[k], ks[1:], v) override(dl[k], ks[1:], v)
else: else:
if len(ks) == 1: if len(ks) == 1:
assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl)) #assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
if not ks[0] in dl:
logger.warning('A new filed ({}) detected!'.format(ks[0], dl))
dl[ks[0]] = str2num(v) dl[ks[0]] = str2num(v)
else: else:
assert ks[0] in dl, (
'({}) doesn\'t exist in {}, a new dict field is invalid'.
format(ks[0], dl))
override(dl[ks[0]], ks[1:], v) override(dl[ks[0]], ks[1:], v)
......
...@@ -74,7 +74,7 @@ def main(args): ...@@ -74,7 +74,7 @@ def main(args):
compiled_valid_prog = program.compile(config, valid_prog) compiled_valid_prog = program.compile(config, valid_prog)
program.run(valid_dataloader, exe, compiled_valid_prog, valid_fetchs, -1, program.run(valid_dataloader, exe, compiled_valid_prog, valid_fetchs, -1,
'eval') 'eval', config)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -410,6 +410,7 @@ def run(dataloader, ...@@ -410,6 +410,7 @@ def run(dataloader,
fetchs, fetchs,
epoch=0, epoch=0,
mode='train', mode='train',
config=None,
vdl_writer=None): vdl_writer=None):
""" """
Feed data to the model and fetch the measures and loss Feed data to the model and fetch the measures and loss
...@@ -443,11 +444,23 @@ def run(dataloader, ...@@ -443,11 +444,23 @@ def run(dataloader,
logger.scaler('loss', metrics[0][0], total_step, vdl_writer) logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
total_step += 1 total_step += 1
if mode == 'eval': if mode == 'eval':
logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str)) if idx % config.get('print_interval', 10) == 0:
logger.info("{:s} step:{:<4d} {:s}".format(mode, idx,
fetchs_str))
else: else:
epoch_str = "epoch:{:<3d}".format(epoch) epoch_str = "epoch:{:<3d}".format(epoch)
step_str = "{:s} step:{:<4d}".format(mode, idx) step_str = "{:s} step:{:<4d}".format(mode, idx)
# Keep the first 10 batches statistics, They are important for develop
if epoch == 0 and idx < 10:
logger.info("{:s} {:s} {:s}".format(
logger.coloring(epoch_str, "HEADER")
if idx == 0 else epoch_str,
logger.coloring(step_str, "PURPLE"),
logger.coloring(fetchs_str, 'OKGREEN')))
else:
if idx % config.get('print_interval', 10) == 0:
logger.info("{:s} {:s} {:s}".format( logger.info("{:s} {:s} {:s}".format(
logger.coloring(epoch_str, "HEADER") logger.coloring(epoch_str, "HEADER")
if idx == 0 else epoch_str, if idx == 0 else epoch_str,
......
...@@ -5,4 +5,5 @@ export PYTHONPATH=$PWD:$PYTHONPATH ...@@ -5,4 +5,5 @@ export PYTHONPATH=$PWD:$PYTHONPATH
python -m paddle.distributed.launch \ python -m paddle.distributed.launch \
--selected_gpus="0,1,2,3" \ --selected_gpus="0,1,2,3" \
tools/train.py \ tools/train.py \
-c ./configs/ResNet/ResNet50.yaml -c ./configs/ResNet/ResNet50.yaml \
-o print_interval=10
...@@ -110,21 +110,21 @@ def main(args): ...@@ -110,21 +110,21 @@ def main(args):
for epoch_id in range(config.epochs): for epoch_id in range(config.epochs):
# 1. train with train dataset # 1. train with train dataset
program.run(train_dataloader, exe, compiled_train_prog, train_fetchs, program.run(train_dataloader, exe, compiled_train_prog, train_fetchs,
epoch_id, 'train', vdl_writer) epoch_id, 'train', config, vdl_writer)
if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0: if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
# 2. validate with validate dataset # 2. validate with validate dataset
if config.validate and epoch_id % config.valid_interval == 0: if config.validate and epoch_id % config.valid_interval == 0:
if config.get('use_ema'): if config.get('use_ema'):
logger.info(logger.coloring("EMA validate start...")) logger.info(logger.coloring("EMA validate start..."))
with ema.apply(exe): with ema.apply(exe):
top1_acc = program.run(valid_dataloader, exe, top1_acc = program.run(
compiled_valid_prog, valid_dataloader, exe, compiled_valid_prog,
valid_fetchs, epoch_id, 'valid') valid_fetchs, epoch_id, 'valid', config)
logger.info(logger.coloring("EMA validate over!")) logger.info(logger.coloring("EMA validate over!"))
top1_acc = program.run(valid_dataloader, exe, top1_acc = program.run(valid_dataloader, exe,
compiled_valid_prog, valid_fetchs, compiled_valid_prog, valid_fetchs,
epoch_id, 'valid') epoch_id, 'valid', config)
if top1_acc > best_top1_acc: if top1_acc > best_top1_acc:
best_top1_acc = top1_acc best_top1_acc = top1_acc
message = "The best top1 acc {:.5f}, in epoch: {:d}".format( message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册