提交 b4614894 编写于 作者: Y Yang Zhang

Safeguard the learning rate map hack

上级 4e9c403e
......@@ -280,7 +280,7 @@ class StaticGraphAdapter(object):
def _make_program(self, inputs):
prog = self._orig_prog.clone(for_test=self.mode != 'train')
if self.mode == 'train':
if self.mode == 'train' and self.model._optimizer._learning_rate_map:
# HACK workaround learning rate map issue
lr_var = self.model._optimizer._learning_rate_map[self._orig_prog]
self.model._optimizer._learning_rate_map[prog] = lr_var
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册