diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 5473e61468adbc3567c1994feb90bed79a43dbaf..394cf050a7d3e2e88688167ee96ce1b9901dfef9 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -23,6 +23,7 @@ from initializer import Constant
 from layer_helper import LayerHelper
 from regularizer import append_regularization_ops
 from clip import append_gradient_clip_ops, error_clip_callback
+from contextlib import contextmanager
 
 __all__ = [
     'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad',
@@ -631,10 +632,10 @@ class ModelAverage(Optimizer):
         for pass_id in range(args.pass_num):
             for data in train_reader():
                 exe.run(fluid.default_main_program()...)
-            model_average.apply()
-            for data in test_reader():
-                exe.run(inference_program...)
-            model_average.restore(exe)
+
+            with model_average.apply(exe):
+                for data in test_reader():
+                    exe.run(inference_program...)
     """
 
     def __init__(self,
@@ -651,6 +652,18 @@ class ModelAverage(Optimizer):
         for param, _ in self.params_grads:
             self._append_average_accumulate_op(param)
 
+        self.apply_program = Program()
+        block = self.apply_program.global_block()
+        with program_guard(main_program=self.apply_program):
+            for param_grad in self.params_grads:
+                self._add_average_apply_op(block, param_grad)
+
+        self.restore_program = Program()
+        block = self.restore_program.global_block()
+        with program_guard(main_program=self.restore_program):
+            for param_grad in self.params_grads:
+                self._add_average_restore_op(block, param_grad)
+
     def _add_average_apply_op(self, block, param_grad):
         param = block.clone_variable(param_grad[0])
         grad = block.clone_variable(param_grad[1])
@@ -714,22 +727,18 @@ class ModelAverage(Optimizer):
                 "max_average_window": self.max_average_window,
             })
 
-    def apply(self, executor):
+    @contextmanager
+    def apply(self, executor, need_restore=True):
         """Apply average values to parameters of current model.
         """
-        apply_program = Program()
-        block = apply_program.global_block()
-        with program_guard(main_program=apply_program):
-            for param_grad in self.params_grads:
-                self._add_average_apply_op(block, param_grad)
-        executor.run(apply_program)
+        executor.run(self.apply_program)
+        try:
+            yield
+        finally:
+            if need_restore:
+                self.restore(executor)
 
     def restore(self, executor):
         """Restore parameter values of current model.
         """
-        restore_program = Program()
-        block = restore_program.global_block()
-        with program_guard(main_program=restore_program):
-            for param_grad in self.params_grads:
-                self._add_average_restore_op(block, param_grad)
-        executor.run(restore_program)
+        executor.run(self.restore_program)
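
The refactor above hinges on two things: the apply/restore Programs are now built once in __init__ instead of on every call, and contextlib.contextmanager plus try/finally guarantees restore runs even if inference raises. Below is a minimal, self-contained sketch of that pattern; FakeExecutor and AveragerSketch are hypothetical stand-ins for fluid.Executor and the real ModelAverage, and plain strings stand in for fluid.Program objects.

    from contextlib import contextmanager

    class FakeExecutor(object):
        """Hypothetical stand-in for fluid.Executor; just records what ran."""
        def run(self, program):
            print("ran %s" % program)

    class AveragerSketch(object):
        """Mimics the refactored ModelAverage: build programs once, reuse."""
        def __init__(self):
            # The real class builds fluid.Program objects under program_guard;
            # plain strings keep this sketch runnable on its own.
            self.apply_program = "apply_program"
            self.restore_program = "restore_program"

        @contextmanager
        def apply(self, executor, need_restore=True):
            executor.run(self.apply_program)    # swap averaged parameters in
            try:
                yield                           # caller runs inference here
            finally:
                if need_restore:
                    self.restore(executor)      # runs even if inference raises

        def restore(self, executor):
            executor.run(self.restore_program)  # put original parameters back

    exe = FakeExecutor()
    with AveragerSketch().apply(exe):
        pass  # inference goes here; restore happens automatically on exit

Passing need_restore=False keeps the averaged parameters in place after the with block exits, which matches the new keyword argument the diff adds to apply.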