提交 cad4d7f3 编写于 作者: W wanghaoshuang

Refine initialization and API of ModelAverage

1. Implement 'with model_average.apply()' syntax
2. Init apply_program and restore_program in __init__ function of ModelAverage
上级 92a01d49
...@@ -23,6 +23,7 @@ from initializer import Constant ...@@ -23,6 +23,7 @@ from initializer import Constant
from layer_helper import LayerHelper from layer_helper import LayerHelper
from regularizer import append_regularization_ops from regularizer import append_regularization_ops
from clip import append_gradient_clip_ops, error_clip_callback from clip import append_gradient_clip_ops, error_clip_callback
from contextlib import contextmanager
__all__ = [ __all__ = [
'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad',
...@@ -631,10 +632,10 @@ class ModelAverage(Optimizer): ...@@ -631,10 +632,10 @@ class ModelAverage(Optimizer):
for pass_id in range(args.pass_num): for pass_id in range(args.pass_num):
for data in train_reader(): for data in train_reader():
exe.run(fluid.default_main_program()...) exe.run(fluid.default_main_program()...)
model_average.apply()
for data in test_reader(): with model_average.apply(exe):
exe.run(inference_program...) for data in test_reader():
model_average.restore(exe) exe.run(inference_program...)
""" """
def __init__(self, def __init__(self,
...@@ -651,6 +652,18 @@ class ModelAverage(Optimizer): ...@@ -651,6 +652,18 @@ class ModelAverage(Optimizer):
for param, _ in self.params_grads: for param, _ in self.params_grads:
self._append_average_accumulate_op(param) self._append_average_accumulate_op(param)
self.apply_program = Program()
block = self.apply_program.global_block()
with program_guard(main_program=self.apply_program):
for param_grad in self.params_grads:
self._add_average_apply_op(block, param_grad)
self.restore_program = Program()
block = self.restore_program.global_block()
with program_guard(main_program=self.restore_program):
for param_grad in self.params_grads:
self._add_average_restore_op(block, param_grad)
def _add_average_apply_op(self, block, param_grad): def _add_average_apply_op(self, block, param_grad):
param = block.clone_variable(param_grad[0]) param = block.clone_variable(param_grad[0])
grad = block.clone_variable(param_grad[1]) grad = block.clone_variable(param_grad[1])
...@@ -714,22 +727,20 @@ class ModelAverage(Optimizer): ...@@ -714,22 +727,20 @@ class ModelAverage(Optimizer):
"max_average_window": self.max_average_window, "max_average_window": self.max_average_window,
}) })
@contextmanager
def apply(self, executor, need_restore=True):
    """Apply averaged parameter values to the current model.

    Runs the pre-built ``apply_program`` to swap the averaged values
    into the model's parameters, yields control to the caller (so
    inference can run with the averaged weights), and on exit restores
    the original parameter values unless ``need_restore`` is False.

    Args:
        executor: Executor used to run ``apply_program`` (and, on
            exit, ``restore_program`` via :meth:`restore`).
        need_restore (bool): Whether to restore the original parameter
            values when the context exits. Defaults to True.
    """
    executor.run(self.apply_program)
    try:
        yield
    finally:
        # Restore in ``finally`` so the model is never left with the
        # averaged weights even if the caller's block raised.
        if need_restore:
            self.restore(executor)
def restore(self, executor):
    """Restore the original parameter values of the current model.

    Runs the pre-built ``restore_program`` to undo a previous
    :meth:`apply`, putting the pre-averaging parameter values back.

    Args:
        executor: Executor used to run ``restore_program``.
    """
    executor.run(self.restore_program)
with program_guard(main_program=restore_program):
for param_grad in self.params_grads:
self._add_average_restore_op(block, param_grad)
executor.run(restore_program)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册