提交 68c9f6ef 编写于 作者: W wanghaoshuang

Fix error while params_grads[1]==None

上级 e01c770c
......@@ -649,20 +649,23 @@ class ModelAverage(Optimizer):
self.min_average_window = min_average_window
self.max_average_window = max_average_window
self.params_grads = params_grads
for param, _ in self.params_grads:
self._append_average_accumulate_op(param)
for param, grad in self.params_grads:
if grad is not None:
self._append_average_accumulate_op(param)
self.apply_program = Program()
block = self.apply_program.global_block()
with program_guard(main_program=self.apply_program):
for param_grad in self.params_grads:
self._add_average_apply_op(block, param_grad)
if param_grad[1] is not None:
self._add_average_apply_op(block, param_grad)
self.restore_program = Program()
block = self.restore_program.global_block()
with program_guard(main_program=self.restore_program):
for param_grad in self.params_grads:
self._add_average_restore_op(block, param_grad)
if param_grad[1] is not None:
self._add_average_restore_op(block, param_grad)
def _add_average_apply_op(self, block, param_grad):
param = block.clone_variable(param_grad[0])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册