提交 e2082162 编写于 作者: littletomatodonkey's avatar littletomatodonkey

add config

上级 81a8e913
......@@ -68,7 +68,7 @@ def main(args):
compiled_valid_prog = program.compile(config, valid_prog)
program.run(valid_dataloader, exe, compiled_valid_prog, valid_fetchs, -1,
'eval')
'eval', config)
if __name__ == '__main__':
......
......@@ -119,7 +119,7 @@ def main(args):
for epoch_id in range(config.epochs):
# 1. train with train dataset
program.run(train_dataloader, exe, compiled_train_prog, train_fetchs,
epoch_id, 'train', vdl_writer)
epoch_id, 'train', config, vdl_writer)
# 2. validate with validate dataset
if config.validate and epoch_id % config.valid_interval == 0:
......@@ -128,11 +128,11 @@ def main(args):
with ema.apply(exe):
top1_acc = program.run(valid_dataloader, exe,
compiled_valid_prog, valid_fetchs,
epoch_id, 'valid')
epoch_id, 'valid', config)
logger.info(logger.coloring("EMA validate over!"))
top1_acc = program.run(valid_dataloader, exe, compiled_valid_prog,
valid_fetchs, epoch_id, 'valid')
valid_fetchs, epoch_id, 'valid', config)
if top1_acc > best_top1_acc:
best_top1_acc = top1_acc
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册