Commit 8069b961 authored by Yang Zhang

Work around the optimizer learning rate map issue

The optimizer uses `default_main_program()` as the key for its learning rate map, but `Program` appears to rely on object hash for equality checks, so a cloned program is not equal to the original.
Parent ccc1213e
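The mismatch is easy to reproduce outside the trainer. Below is a minimal sketch (assuming the paddle.fluid 1.x `Program.clone` API; the `lr_map` dict here merely stands in for the optimizer's internal `_learning_rate_map`):

```python
import paddle.fluid as fluid

# Sketch of the issue: dict lookups on Program objects fall back to object
# identity, so a cloned program is a different key than the original.
orig = fluid.default_main_program()
clone = orig.clone(for_test=False)

lr_map = {orig: 'lr_0'}   # stands in for the optimizer's _learning_rate_map
print(clone in lr_map)    # False -- the clone hashes as a different object

# The workaround applied in this commit: register the cloned program under
# the same learning rate variable as the original program.
lr_map[clone] = lr_map[orig]
print(clone in lr_map)    # True
```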
@@ -78,7 +78,7 @@ class StaticGraphAdapter(object):
         # with `_build_once` gone, parameters are now created in `__init__`
         # so we need to keep track of the parameters already created
         self._startup_prog = fluid.default_startup_program()
-        self._main_prog = fluid.default_main_program()
+        self._orig_prog = fluid.default_main_program()
         self._label_vars = None  # label variables
         self._endpoints = {}
@@ -241,7 +241,12 @@ class StaticGraphAdapter(object):
         return out[:num_output], out[num_output:]

     def _make_program(self, inputs):
-        prog = self._main_prog.clone(self.mode != 'train')
+        prog = self._orig_prog.clone(for_test=self.mode != 'train')
+        if self.mode == 'train':
+            # HACK workaround learning rate map issue
+            lr_var = self.model._optimizer._learning_rate_map[self._orig_prog]
+            self.model._optimizer._learning_rate_map[prog] = lr_var
         losses = []
         with fluid.program_guard(prog, self._startup_prog):
             outputs = to_list(self.model.forward(*inputs))
             losses = []