From 311c92a449692ba064fa511ace83ad8300c61972 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Mon, 19 Mar 2018 18:50:07 +0800 Subject: [PATCH] Add syntax 'with average_model.apply(exe)' --- fluid/ocr_recognition/ctc_train.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/fluid/ocr_recognition/ctc_train.py b/fluid/ocr_recognition/ctc_train.py index 922f70f0..4a68ebdd 100644 --- a/fluid/ocr_recognition/ctc_train.py +++ b/fluid/ocr_recognition/ctc_train.py @@ -81,17 +81,16 @@ def train(args, data_reader=dummy_reader): sys.stdout.flush() batch_id += 1 - if model_average != None: - model_average.apply(exe) - error_evaluator.reset(exe) - for data in test_reader(): - exe.run(inference_program, feed=get_feeder_data(data, place)) - _, test_seq_error = error_evaluator.eval(exe) - if model_average != None: - model_average.restore(exe) + with model_average.apply(exe): + error_evaluator.reset(exe) + for data in test_reader(): + exe.run(inference_program, feed=get_feeder_data(data, place)) + _, test_seq_error = error_evaluator.eval(exe) + if model_average != None: + model_average.restore(exe) - print "\nEnd pass[%d]; Test seq error: %s.\n" % ( - pass_id, str(test_seq_error[0])) + print "\nEnd pass[%d]; Test seq error: %s.\n" % ( + pass_id, str(test_seq_error[0])) def main(): args = parser.parse_args() -- GitLab