diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 654627af445e33cc040935106b7f2316479cb4e1..7a833c82b9dfb4733237498e47e349288dba6b3a 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -86,6 +86,7 @@ from paddle.fluid.layers.math_op_patch import monkey_patch_variable from . import install_check from .dygraph.nn import * from .dygraph.layers import * +from .dygraph.base import enable_dygraph, disable_dygraph from .io import save, load, load_program_state, set_program_state from .dygraph.checkpoint import save_dygraph, load_dygraph from .dygraph.varbase_patch_methods import monkey_patch_varbase @@ -103,6 +104,8 @@ __all__ = framework.__all__ + executor.__all__ + \ 'contrib', 'data', 'dygraph', + 'enable_dygraph', + 'disable_dygraph', 'transpiler', 'nets', 'optimizer', diff --git a/python/paddle/fluid/dygraph/base.py b/python/paddle/fluid/dygraph/base.py index a59fad7e5567e2afe433530c83f2e18b383205d6..e0460eadcb2950508fcca64529a2253112805e4b 100644 --- a/python/paddle/fluid/dygraph/base.py +++ b/python/paddle/fluid/dygraph/base.py @@ -13,6 +13,7 @@ # limitations under the License. from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator import contextlib +import sys import numpy as np from paddle.fluid import core from paddle.fluid import framework @@ -23,6 +24,9 @@ import objgraph __all__ = [ 'no_grad', 'guard', + 'enable_dygraph', + 'disable_dygraph', + 'enabled', 'to_variable', ] @@ -49,12 +53,85 @@ def program_desc_tracing_guard(enable): tracer._enable_program_desc_tracing = original_val -# This function should be removed in V1.6, because it can easily lead to cyclic dependencies. +_functional_dygraph_context_manager = None + + def enabled(): - # Internal use only + """ + This function checks whether the program runs in dynamic graph mode or not. + You can enter dynamic graph mode with :ref:`api_fluid_dygraph_guard` api, + or enable and disable dynamic graph mode with :ref:`api_fluid_dygraph_enable` + and :ref:`api_fluid_dygraph_disable` api . + + **Note**: + ``fluid.dygraph.enabled`` is the alias of ``fluid.in_dygraph_mode``, and + ``fluid.in_dygraph_mode`` is recommended to use. + + Returns: + bool: Whether the program is running in dynamic graph mode. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + + fluid.enable_dygraph() # Now we are in dygragh mode + print(fluid.dygraph.enabled()) # True + fluid.disable_dygraph() + print(fluid.dygraph.enabled()) # False + """ return framework.in_dygraph_mode() +def enable_dygraph(place=None): + """ + This function enables dynamic graph mode. + + Parameters: + place(fluid.CPUPlace or fluid.CUDAPlace, optional): Place to execute dygraph. + If None, the running place will be determined according to the way of paddle compilation. Default: None + + return: + None + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + + fluid.enable_dygraph() # Now we are in dygragh mode + print(fluid.in_dygraph_mode()) # True + fluid.disable_dygraph() + print(fluid.in_dygraph_mode()) # False + """ + global _functional_dygraph_context_manager + _functional_dygraph_context_manager = guard(place=place) + _functional_dygraph_context_manager.__enter__() + + +def disable_dygraph(): + """ + This function disables dynamic graph mode. + + return: + None + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + + fluid.enable_dygraph() # Now we are in dygragh mode + print(fluid.in_dygraph_mode()) # True + fluid.disable_dygraph() + print(fluid.in_dygraph_mode()) # False + """ + global _functional_dygraph_context_manager + if _functional_dygraph_context_manager is not None: + _functional_dygraph_context_manager.__exit__(*sys.exc_info()) + _functional_dygraph_context_manager = None + + @contextlib.contextmanager def _switch_tracer_mode_guard_(is_train=True): tracer = framework._dygraph_tracer() diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index ebd02101c2bc4861563247777636db98e85af70a..38408d2b9a9d42ed23291696391f6ac3f6595a4b 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -173,7 +173,9 @@ def require_version(min_version, max_version=None): def in_dygraph_mode(): """ This function checks whether the program runs in dynamic graph mode or not. - You can turn on dynamic graph mode with :ref:`api_fluid_dygraph_guard` api. + You can enter dynamic graph mode with :ref:`api_fluid_dygraph_guard` api, + or enable and disable dynamic graph mode with :ref:`api_fluid_dygraph_enable` + and :ref:`api_fluid_dygraph_disable` api . Returns: bool: Whether the program is running in dynamic graph mode. @@ -182,11 +184,11 @@ def in_dygraph_mode(): .. code-block:: python import paddle.fluid as fluid - if fluid.in_dygraph_mode(): - print('running in dygraph mode') - else: - print('not running in dygraph mode') + fluid.enable_dygraph() # Now we are in dygragh mode + print(fluid.in_dygraph_mode()) # True + fluid.disable_dygraph() + print(fluid.in_dygraph_mode()) # False """ return _dygraph_tracer_ is not None diff --git a/python/paddle/fluid/tests/unittests/test_imperative_basic.py b/python/paddle/fluid/tests/unittests/test_imperative_basic.py index 944d7bd3a85ad54020e71ec81dd6f4829a65bf92..b43e284d8c49a4c154b0e581ddc5d2fc52189c3c 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_basic.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_basic.py @@ -177,6 +177,31 @@ class SimpleRNN(fluid.Layer): class TestImperative(unittest.TestCase): + def test_functional_dygraph_context(self): + self.assertFalse(fluid.dygraph.enabled()) + fluid.enable_dygraph() + self.assertTrue(fluid.dygraph.enabled()) + np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + var_inp = fluid.dygraph.base.to_variable(np_inp) + mlp = MLP(input_size=2) + out = mlp(var_inp) + dy_out1 = out.numpy() + out.backward() + dy_grad1 = mlp._linear1.weight.gradient() + fluid.disable_dygraph() + self.assertFalse(fluid.dygraph.enabled()) + with fluid.dygraph.guard(): + self.assertTrue(fluid.dygraph.enabled()) + var_inp = fluid.dygraph.base.to_variable(np_inp) + mlp = MLP(input_size=2) + out = mlp(var_inp) + dy_out2 = out.numpy() + out.backward() + dy_grad2 = mlp._linear1.weight.gradient() + self.assertFalse(fluid.dygraph.enabled()) + self.assertTrue(np.array_equal(dy_out1, dy_out2)) + self.assertTrue(np.array_equal(dy_grad1, dy_grad2)) + def test_isinstance(self): var = fluid.layers.data(shape=[1], name='x', dtype='float32') self.assertTrue(isinstance(var, fluid.Variable))