未验证 提交 df87e79f 编写于 作者: S songyouwei 提交者: GitHub

Add functional dygraph mode api (#22745)

* functional dygraph enable/disable
test=develop

* use context manager instead
test=develop

* refine sample code
test=develop

* rename api & expose to fluid
test=develop

* fix sample code
test=develop
上级 0f9d4081
...@@ -86,6 +86,7 @@ from paddle.fluid.layers.math_op_patch import monkey_patch_variable ...@@ -86,6 +86,7 @@ from paddle.fluid.layers.math_op_patch import monkey_patch_variable
from . import install_check from . import install_check
from .dygraph.nn import * from .dygraph.nn import *
from .dygraph.layers import * from .dygraph.layers import *
from .dygraph.base import enable_dygraph, disable_dygraph
from .io import save, load, load_program_state, set_program_state from .io import save, load, load_program_state, set_program_state
from .dygraph.checkpoint import save_dygraph, load_dygraph from .dygraph.checkpoint import save_dygraph, load_dygraph
from .dygraph.varbase_patch_methods import monkey_patch_varbase from .dygraph.varbase_patch_methods import monkey_patch_varbase
...@@ -103,6 +104,8 @@ __all__ = framework.__all__ + executor.__all__ + \ ...@@ -103,6 +104,8 @@ __all__ = framework.__all__ + executor.__all__ + \
'contrib', 'contrib',
'data', 'data',
'dygraph', 'dygraph',
'enable_dygraph',
'disable_dygraph',
'transpiler', 'transpiler',
'nets', 'nets',
'optimizer', 'optimizer',
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator
import contextlib import contextlib
import sys
import numpy as np import numpy as np
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid import framework from paddle.fluid import framework
...@@ -23,6 +24,9 @@ import objgraph ...@@ -23,6 +24,9 @@ import objgraph
__all__ = [ __all__ = [
'no_grad', 'no_grad',
'guard', 'guard',
'enable_dygraph',
'disable_dygraph',
'enabled',
'to_variable', 'to_variable',
] ]
...@@ -49,12 +53,85 @@ def program_desc_tracing_guard(enable): ...@@ -49,12 +53,85 @@ def program_desc_tracing_guard(enable):
tracer._enable_program_desc_tracing = original_val tracer._enable_program_desc_tracing = original_val
# This function should be removed in V1.6, because it can easily lead to cyclic dependencies. _functional_dygraph_context_manager = None
def enabled(): def enabled():
# Internal use only """
This function checks whether the program runs in dynamic graph mode or not.
You can enter dynamic graph mode with :ref:`api_fluid_dygraph_guard` api,
or enable and disable dynamic graph mode with :ref:`api_fluid_dygraph_enable`
and :ref:`api_fluid_dygraph_disable` api .
**Note**:
``fluid.dygraph.enabled`` is the alias of ``fluid.in_dygraph_mode``, and
``fluid.in_dygraph_mode`` is recommended to use.
Returns:
bool: Whether the program is running in dynamic graph mode.
Examples:
.. code-block:: python
import paddle.fluid as fluid
fluid.enable_dygraph() # Now we are in dygragh mode
print(fluid.dygraph.enabled()) # True
fluid.disable_dygraph()
print(fluid.dygraph.enabled()) # False
"""
return framework.in_dygraph_mode() return framework.in_dygraph_mode()
def enable_dygraph(place=None):
"""
This function enables dynamic graph mode.
Parameters:
place(fluid.CPUPlace or fluid.CUDAPlace, optional): Place to execute dygraph.
If None, the running place will be determined according to the way of paddle compilation. Default: None
return:
None
Examples:
.. code-block:: python
import paddle.fluid as fluid
fluid.enable_dygraph() # Now we are in dygragh mode
print(fluid.in_dygraph_mode()) # True
fluid.disable_dygraph()
print(fluid.in_dygraph_mode()) # False
"""
global _functional_dygraph_context_manager
_functional_dygraph_context_manager = guard(place=place)
_functional_dygraph_context_manager.__enter__()
def disable_dygraph():
"""
This function disables dynamic graph mode.
return:
None
Examples:
.. code-block:: python
import paddle.fluid as fluid
fluid.enable_dygraph() # Now we are in dygragh mode
print(fluid.in_dygraph_mode()) # True
fluid.disable_dygraph()
print(fluid.in_dygraph_mode()) # False
"""
global _functional_dygraph_context_manager
if _functional_dygraph_context_manager is not None:
_functional_dygraph_context_manager.__exit__(*sys.exc_info())
_functional_dygraph_context_manager = None
@contextlib.contextmanager @contextlib.contextmanager
def _switch_tracer_mode_guard_(is_train=True): def _switch_tracer_mode_guard_(is_train=True):
tracer = framework._dygraph_tracer() tracer = framework._dygraph_tracer()
......
...@@ -173,7 +173,9 @@ def require_version(min_version, max_version=None): ...@@ -173,7 +173,9 @@ def require_version(min_version, max_version=None):
def in_dygraph_mode(): def in_dygraph_mode():
""" """
This function checks whether the program runs in dynamic graph mode or not. This function checks whether the program runs in dynamic graph mode or not.
You can turn on dynamic graph mode with :ref:`api_fluid_dygraph_guard` api. You can enter dynamic graph mode with :ref:`api_fluid_dygraph_guard` api,
or enable and disable dynamic graph mode with :ref:`api_fluid_dygraph_enable`
and :ref:`api_fluid_dygraph_disable` api .
Returns: Returns:
bool: Whether the program is running in dynamic graph mode. bool: Whether the program is running in dynamic graph mode.
...@@ -182,11 +184,11 @@ def in_dygraph_mode(): ...@@ -182,11 +184,11 @@ def in_dygraph_mode():
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle.fluid as fluid
if fluid.in_dygraph_mode():
print('running in dygraph mode')
else:
print('not running in dygraph mode')
fluid.enable_dygraph() # Now we are in dygragh mode
print(fluid.in_dygraph_mode()) # True
fluid.disable_dygraph()
print(fluid.in_dygraph_mode()) # False
""" """
return _dygraph_tracer_ is not None return _dygraph_tracer_ is not None
......
...@@ -177,6 +177,31 @@ class SimpleRNN(fluid.Layer): ...@@ -177,6 +177,31 @@ class SimpleRNN(fluid.Layer):
class TestImperative(unittest.TestCase): class TestImperative(unittest.TestCase):
def test_functional_dygraph_context(self):
self.assertFalse(fluid.dygraph.enabled())
fluid.enable_dygraph()
self.assertTrue(fluid.dygraph.enabled())
np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
var_inp = fluid.dygraph.base.to_variable(np_inp)
mlp = MLP(input_size=2)
out = mlp(var_inp)
dy_out1 = out.numpy()
out.backward()
dy_grad1 = mlp._linear1.weight.gradient()
fluid.disable_dygraph()
self.assertFalse(fluid.dygraph.enabled())
with fluid.dygraph.guard():
self.assertTrue(fluid.dygraph.enabled())
var_inp = fluid.dygraph.base.to_variable(np_inp)
mlp = MLP(input_size=2)
out = mlp(var_inp)
dy_out2 = out.numpy()
out.backward()
dy_grad2 = mlp._linear1.weight.gradient()
self.assertFalse(fluid.dygraph.enabled())
self.assertTrue(np.array_equal(dy_out1, dy_out2))
self.assertTrue(np.array_equal(dy_grad1, dy_grad2))
def test_isinstance(self): def test_isinstance(self):
var = fluid.layers.data(shape=[1], name='x', dtype='float32') var = fluid.layers.data(shape=[1], name='x', dtype='float32')
self.assertTrue(isinstance(var, fluid.Variable)) self.assertTrue(isinstance(var, fluid.Variable))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册