From 9fe93b154ae0ad920da8fa9d7350c1a27bc1faf0 Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Wed, 13 May 2020 17:03:40 +0800 Subject: [PATCH] add enable_imperative, disable_imperative (#24387) * add enable_imperative, disable_imperative; test=develop * add unitest; test=develop * polish example; test=develop --- python/paddle/fluid/__init__.py | 4 +- python/paddle/fluid/dygraph/base.py | 50 +++++++++++++++++++ .../tests/unittests/test_imperative_basic.py | 14 ++++++ 3 files changed, 67 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 1be99cb4a7..6894b9d66e 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -84,7 +84,7 @@ from paddle.fluid.layers.math_op_patch import monkey_patch_variable from . import install_check from .dygraph.nn import * from .dygraph.layers import * -from .dygraph.base import enable_dygraph, disable_dygraph +from .dygraph.base import enable_dygraph, disable_dygraph, enable_imperative, disable_imperative from .io import save, load, load_program_state, set_program_state from .dygraph.checkpoint import save_dygraph, load_dygraph from .dygraph.varbase_patch_methods import monkey_patch_varbase @@ -104,6 +104,8 @@ __all__ = framework.__all__ + executor.__all__ + \ 'dygraph', 'enable_dygraph', 'disable_dygraph', + 'enable_imperative', + 'disable_imperative', 'transpiler', 'nets', 'optimizer', diff --git a/python/paddle/fluid/dygraph/base.py b/python/paddle/fluid/dygraph/base.py index 627376efa1..525563428b 100644 --- a/python/paddle/fluid/dygraph/base.py +++ b/python/paddle/fluid/dygraph/base.py @@ -29,6 +29,8 @@ __all__ = [ 'guard', 'enable_dygraph', 'disable_dygraph', + 'enable_imperative', + 'disable_imperative', 'enabled', 'to_variable', ] @@ -86,6 +88,54 @@ def enabled(): return framework.in_dygraph_mode() +def enable_imperative(place=None): + """ + This function enables imperative mode. + + Parameters: + place(fluid.CPUPlace or fluid.CUDAPlace, optional): Place for imperative execution. + If None, the running place will be determined according to the way of paddle compilation. Default: None + + return: + None + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + + fluid.enable_imperative() # Now we are in imperative mode + x = fluid.layers.ones( (2, 2), "float32") + y = fluid.layers.zeros( (2, 2), "float32") + z = x + y + print( z.numpy() ) #[[1, 1], [1, 1]] + + """ + enable_dygraph(place) + + +def disable_imperative(): + """ + This function disables imperative mode. + + return: + None + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + + fluid.enable_imperative() # Now we are in imperative mode + x = fluid.layers.ones( (2, 2), "float32") + y = fluid.layers.zeros( (2, 2), "float32") + z = x + y + print( z.numpy() ) #[[1, 1], [1, 1]] + fluid.disable_imperative() # Now we are in declarative mode + """ + disable_dygraph() + + def enable_dygraph(place=None): """ This function enables dynamic graph mode. diff --git a/python/paddle/fluid/tests/unittests/test_imperative_basic.py b/python/paddle/fluid/tests/unittests/test_imperative_basic.py index 47831341c4..a427557aae 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_basic.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_basic.py @@ -203,6 +203,20 @@ class TestImperative(unittest.TestCase): self.assertTrue(np.array_equal(dy_out1, dy_out2)) self.assertTrue(np.array_equal(dy_grad1, dy_grad2)) + def test_anable_imperative(self): + self.assertFalse(fluid.dygraph.enabled()) + fluid.enable_imperative() + self.assertTrue(fluid.dygraph.enabled()) + np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + var_inp = fluid.dygraph.base.to_variable(np_inp) + mlp = MLP(input_size=2) + out = mlp(var_inp) + dy_out1 = out.numpy() + out.backward() + dy_grad1 = mlp._linear1.weight.gradient() + fluid.disable_imperative() + self.assertFalse(fluid.dygraph.enabled()) + def test_isinstance(self): var = fluid.layers.data(shape=[1], name='x', dtype='float32') self.assertTrue(isinstance(var, fluid.Variable)) -- GitLab