diff --git a/fluid/ocr_recognition/ctc_train.py b/fluid/ocr_recognition/ctc_train.py index 922f70f05722fa14d322b2061900092fd2adb0dd..4a68ebdd2e5fd0b50c9085965ce9ebed7829281c 100644 --- a/fluid/ocr_recognition/ctc_train.py +++ b/fluid/ocr_recognition/ctc_train.py @@ -81,17 +81,16 @@ def train(args, data_reader=dummy_reader): sys.stdout.flush() batch_id += 1 - if model_average != None: - model_average.apply(exe) - error_evaluator.reset(exe) - for data in test_reader(): - exe.run(inference_program, feed=get_feeder_data(data, place)) - _, test_seq_error = error_evaluator.eval(exe) - if model_average != None: - model_average.restore(exe) + with model_average.apply(exe): + error_evaluator.reset(exe) + for data in test_reader(): + exe.run(inference_program, feed=get_feeder_data(data, place)) + _, test_seq_error = error_evaluator.eval(exe) + if model_average != None: + model_average.restore(exe) - print "\nEnd pass[%d]; Test seq error: %s.\n" % ( - pass_id, str(test_seq_error[0])) + print "\nEnd pass[%d]; Test seq error: %s.\n" % ( + pass_id, str(test_seq_error[0])) def main(): args = parser.parse_args()