Commit d3bad33f authored by shippingwang

add print_interval and refine override

Parent 9d3f36b7
......@@ -144,9 +144,14 @@ def override(dl, ks, v):
override(dl[k], ks[1:], v)
else:
if len(ks) == 1:
assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
#assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
if not ks[0] in dl:
logger.warning('A new field ({}) detected!'.format(ks[0]))
dl[ks[0]] = str2num(v)
else:
assert ks[0] in dl, (
'({}) doesn\'t exist in {}, a new dict field is invalid'.
format(ks[0], dl))
override(dl[ks[0]], ks[1:], v)
......
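The refined override above keeps updates to existing keys silent, downgrades an unknown leaf key from a hard assert to a warning, and still refuses to descend into a dict key that does not exist. A standalone sketch of that contract, with the project's logger and str2num replaced by plain equivalents (only the dict branch is reproduced; the list branch is untouched by this commit):

import warnings

def override(dl, ks, v):
    # Illustrative re-implementation of the dict branch shown in the diff above.
    if len(ks) == 1:
        if ks[0] not in dl:
            warnings.warn('A new field ({}) detected!'.format(ks[0]))
        dl[ks[0]] = v  # the real code converts v with str2num first
    else:
        assert ks[0] in dl, (
            "({}) doesn't exist in {}, a new dict field is invalid".format(ks[0], dl))
        override(dl[ks[0]], ks[1:], v)

cfg = {'TRAIN': {'batch_size': 32}}
override(cfg, 'TRAIN.batch_size'.split('.'), 64)    # existing key: updated silently
override(cfg, 'print_interval'.split('.'), 10)      # new leaf key: warning, then added
# override(cfg, 'VALID.batch_size'.split('.'), 64)  # would raise: 'VALID' is not in cfg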
......@@ -74,7 +74,7 @@ def main(args):
compiled_valid_prog = program.compile(config, valid_prog)
program.run(valid_dataloader, exe, compiled_valid_prog, valid_fetchs, -1,
'eval')
'eval', config)
if __name__ == '__main__':
......
......@@ -410,6 +410,7 @@ def run(dataloader,
fetchs,
epoch=0,
mode='train',
config=None,
vdl_writer=None):
"""
Feed data to the model and fetch the measures and loss
......@@ -443,16 +444,28 @@ def run(dataloader,
logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
total_step += 1
if mode == 'eval':
logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str))
if idx % config.get('print_interval', 1) == 0:
logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx,
fetchs_str))
else:
epoch_str = "epoch:{:<3d}".format(epoch)
step_str = "{:s} step:{:<4d}".format(mode, idx)
logger.info("{:s} {:s} {:s}".format(
logger.coloring(epoch_str, "HEADER")
if idx == 0 else epoch_str,
logger.coloring(step_str, "PURPLE"),
logger.coloring(fetchs_str, 'OKGREEN')))
# Keep statistics for the first 10 batches; they are important for development
if epoch == 0 and idx < 10:
logger.info("{:s} {:s} {:s}".format(
logger.coloring(epoch_str, "HEADER")
if idx == 0 else epoch_str,
logger.coloring(step_str, "PURPLE"),
logger.coloring(fetchs_str, 'OKGREEN')))
else:
if idx % config.get('print_interval', 1) == 0:
logger.info("{:s} {:s} {:s}".format(
logger.coloring(epoch_str, "HEADER")
if idx == 0 else epoch_str,
logger.coloring(step_str, "PURPLE"),
logger.coloring(fetchs_str, 'OKGREEN')))
end_str = ''.join([str(m.mean) + ' '
for m in metric_list] + [batch_time.total]) + 's'
......
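Taken together, the new conditions in run() throttle logging by print_interval (default 1): eval steps are logged every print_interval iterations, while training additionally always logs the first 10 batches of epoch 0, since those matter when bringing a model up. A small sketch of the training-branch decision with a hypothetical helper (should_log is not part of the project):

def should_log(epoch, idx, config=None):
    # Hedged sketch of the gating logic added to program.run above.
    interval = (config or {}).get('print_interval', 1)
    if epoch == 0 and idx < 10:
        return True  # always keep the first 10 batches of the first epoch
    return idx % interval == 0

cfg = {'print_interval': 10}
print([i for i in range(30) if should_log(0, i, cfg)])
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20]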
......@@ -5,4 +5,5 @@ export PYTHONPATH=$PWD:$PYTHONPATH
python -m paddle.distributed.launch \
--selected_gpus="0,1,2,3" \
tools/train.py \
-c ./configs/ResNet/ResNet50.yaml
-c ./configs/ResNet/ResNet50.yaml \
-o print_interval=10
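run.sh now passes -o print_interval=10 on top of the YAML config. A minimal sketch of how such key=value overrides are expected to reach override(), assuming a hypothetical parse_overrides helper and the override() sketched after the first hunk (the real loader also converts values with str2num):

def parse_overrides(config, overrides):
    # Hypothetical helper: split 'a.b=c' items and apply them via override().
    for item in overrides:
        key, value = item.split('=', 1)
        override(config, key.split('.'), value)
    return config

cfg = {'TRAIN': {'batch_size': 32}}
parse_overrides(cfg, ['print_interval=10', 'TRAIN.batch_size=64'])
# cfg now holds the new top-level key and the updated (string-valued) batch size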
......@@ -110,21 +110,21 @@ def main(args):
for epoch_id in range(config.epochs):
# 1. train with train dataset
program.run(train_dataloader, exe, compiled_train_prog, train_fetchs,
epoch_id, 'train', vdl_writer)
epoch_id, 'train', config, vdl_writer)
if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
# 2. validate with validate dataset
if config.validate and epoch_id % config.valid_interval == 0:
if config.get('use_ema'):
logger.info(logger.coloring("EMA validate start..."))
with ema.apply(exe):
top1_acc = program.run(valid_dataloader, exe,
compiled_valid_prog,
valid_fetchs, epoch_id, 'valid')
top1_acc = program.run(
valid_dataloader, exe, compiled_valid_prog,
valid_fetchs, epoch_id, 'valid', config)
logger.info(logger.coloring("EMA validate over!"))
top1_acc = program.run(valid_dataloader, exe,
compiled_valid_prog, valid_fetchs,
epoch_id, 'valid')
epoch_id, 'valid', config)
if top1_acc > best_top1_acc:
best_top1_acc = top1_acc
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
......