diff --git a/tools/train.py b/tools/train.py index c7ae73203c5a6d9cccbc939ce8ac98859ec92545..607f4c698f7ac8964d9e813214ae088266d309a5 100644 --- a/tools/train.py +++ b/tools/train.py @@ -133,7 +133,7 @@ def main(args): top1_acc = program.run(valid_dataloader, exe, compiled_valid_prog, valid_fetchs, - epoch_id, 'valid') + epoch_id, 'valid', config) if top1_acc > best_top1_acc: best_top1_acc = top1_acc message = "The best top1 acc {:.5f}, in epoch: {:d}".format(