From 97aa1838bc0eb276a2b12589f09f1f04666065b6 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Tue, 16 Apr 2019 16:53:17 +0800 Subject: [PATCH] Fix dygraph train mode test=develop --- python/paddle/fluid/dygraph/layers.py | 8 ++++---- python/paddle/fluid/dygraph/tracer.py | 4 ++-- .../paddle/fluid/tests/unittests/test_imperative_mnist.py | 1 + 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index 6b78e2abb..c772e5089 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -49,10 +49,10 @@ class Layer(core.Layer): self._helper = LayerObjectHelper(self._full_name) def train(self): - framework._dygraph_tracer()._train_mode() + framework._dygraph_tracer().train_mode() def eval(self): - framework._dygraph_tracer()._eval_mode() + framework._dygraph_tracer().eval_mode() def full_name(self): """Full name for this layers. @@ -261,10 +261,10 @@ class PyLayer(core.PyLayer): super(PyLayer, self).__init__() def train(self): - framework._dygraph_tracer()._train_mode() + framework._dygraph_tracer().train_mode() def eval(self): - framework._dygraph_tracer()._eval_mode() + framework._dygraph_tracer().eval_mode() @classmethod def _do_forward(cls, inputs): diff --git a/python/paddle/fluid/dygraph/tracer.py b/python/paddle/fluid/dygraph/tracer.py index ee37ffab2..9d2cbb4f0 100644 --- a/python/paddle/fluid/dygraph/tracer.py +++ b/python/paddle/fluid/dygraph/tracer.py @@ -118,8 +118,8 @@ class Tracer(core.Tracer): if k in backward_refs: op.backward_refs[k] = outputs[k] - def _train_mode(self): + def train_mode(self): self._train_mode = True - def _eval_mode(self): + def eval_mode(self): self._train_mode = False diff --git a/python/paddle/fluid/tests/unittests/test_imperative_mnist.py b/python/paddle/fluid/tests/unittests/test_imperative_mnist.py index 76b8d3aa3..908237b88 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_mnist.py @@ -117,6 +117,7 @@ class TestImperativeMnist(unittest.TestCase): train_reader = paddle.batch( paddle.dataset.mnist.train(), batch_size=128, drop_last=True) + mnist.train() dy_param_init_value = {} for epoch in range(epoch_num): for batch_id, data in enumerate(train_reader()): -- GitLab