未验证 提交 e4e5bad6 编写于 作者: L lujun 提交者: GitHub

Merge pull request #16908 from velconia/local_rel_1_4_dygraph_untrack_op

imperative fix train mode
...@@ -49,10 +49,10 @@ class Layer(core.Layer): ...@@ -49,10 +49,10 @@ class Layer(core.Layer):
self._helper = LayerObjectHelper(self._full_name) self._helper = LayerObjectHelper(self._full_name)
def train(self): def train(self):
framework._dygraph_tracer()._train_mode() framework._dygraph_tracer().train_mode()
def eval(self): def eval(self):
framework._dygraph_tracer()._eval_mode() framework._dygraph_tracer().eval_mode()
def full_name(self): def full_name(self):
"""Full name for this layers. """Full name for this layers.
...@@ -261,10 +261,10 @@ class PyLayer(core.PyLayer): ...@@ -261,10 +261,10 @@ class PyLayer(core.PyLayer):
super(PyLayer, self).__init__() super(PyLayer, self).__init__()
def train(self): def train(self):
framework._dygraph_tracer()._train_mode() framework._dygraph_tracer().train_mode()
def eval(self): def eval(self):
framework._dygraph_tracer()._eval_mode() framework._dygraph_tracer().eval_mode()
@classmethod @classmethod
def _do_forward(cls, inputs): def _do_forward(cls, inputs):
......
...@@ -118,8 +118,8 @@ class Tracer(core.Tracer): ...@@ -118,8 +118,8 @@ class Tracer(core.Tracer):
if k in backward_refs: if k in backward_refs:
op.backward_refs[k] = outputs[k] op.backward_refs[k] = outputs[k]
def _train_mode(self): def train_mode(self):
self._train_mode = True self._train_mode = True
def _eval_mode(self): def eval_mode(self):
self._train_mode = False self._train_mode = False
...@@ -117,6 +117,7 @@ class TestImperativeMnist(unittest.TestCase): ...@@ -117,6 +117,7 @@ class TestImperativeMnist(unittest.TestCase):
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=128, drop_last=True) paddle.dataset.mnist.train(), batch_size=128, drop_last=True)
mnist.train()
dy_param_init_value = {} dy_param_init_value = {}
for epoch in range(epoch_num): for epoch in range(epoch_num):
for batch_id, data in enumerate(train_reader()): for batch_id, data in enumerate(train_reader()):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册