提交 cad4d7f3 编写于 作者: W wanghaoshuang

Refine initial and API of ModelAverage API

1. Implement 'with model_average.apply()' syntax
2. Init apply_program and restore_program in __init__ functin of ModelAverage
上级 92a01d49
......@@ -23,6 +23,7 @@ from initializer import Constant
from layer_helper import LayerHelper
from regularizer import append_regularization_ops
from clip import append_gradient_clip_ops, error_clip_callback
from contextlib import contextmanager
__all__ = [
'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad',
......@@ -631,10 +632,10 @@ class ModelAverage(Optimizer):
for pass_id in range(args.pass_num):
for data in train_reader():
exe.run(fluid.default_main_program()...)
model_average.apply()
with model_average.apply(exe):
for data in test_reader():
exe.run(inference_program...)
model_average.restore(exe)
"""
def __init__(self,
......@@ -651,6 +652,18 @@ class ModelAverage(Optimizer):
for param, _ in self.params_grads:
self._append_average_accumulate_op(param)
self.apply_program = Program()
block = self.apply_program.global_block()
with program_guard(main_program=self.apply_program):
for param_grad in self.params_grads:
self._add_average_apply_op(block, param_grad)
self.restore_program = Program()
block = self.restore_program.global_block()
with program_guard(main_program=self.restore_program):
for param_grad in self.params_grads:
self._add_average_restore_op(block, param_grad)
def _add_average_apply_op(self, block, param_grad):
param = block.clone_variable(param_grad[0])
grad = block.clone_variable(param_grad[1])
......@@ -714,22 +727,20 @@ class ModelAverage(Optimizer):
"max_average_window": self.max_average_window,
})
def apply(self, executor):
@contextmanager
def apply(self, executor, need_restore=True):
"""Apply average values to parameters of current model.
"""
apply_program = Program()
block = apply_program.global_block()
with program_guard(main_program=apply_program):
for param_grad in self.params_grads:
self._add_average_apply_op(block, param_grad)
executor.run(apply_program)
executor.run(self.apply_program)
print "finish apply"
try:
yield
finally:
if need_restore:
self.restore(executor)
def restore(self, executor):
"""Restore parameter values of current model.
"""
restore_program = Program()
block = restore_program.global_block()
with program_guard(main_program=restore_program):
for param_grad in self.params_grads:
self._add_average_restore_op(block, param_grad)
executor.run(restore_program)
executor.run(self.restore_program)
print "finish restore"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册