未验证 提交 df87e79f 编写于 作者: S songyouwei 提交者: GitHub

Add functional dygraph mode api (#22745)

* functional dygraph enable/disable
test=develop

* use context manager instead
test=develop

* refine sample code
test=develop

* rename api & expose to fluid
test=develop

* fix sample code
test=develop
上级 0f9d4081
......@@ -86,6 +86,7 @@ from paddle.fluid.layers.math_op_patch import monkey_patch_variable
from . import install_check
from .dygraph.nn import *
from .dygraph.layers import *
from .dygraph.base import enable_dygraph, disable_dygraph
from .io import save, load, load_program_state, set_program_state
from .dygraph.checkpoint import save_dygraph, load_dygraph
from .dygraph.varbase_patch_methods import monkey_patch_varbase
......@@ -103,6 +104,8 @@ __all__ = framework.__all__ + executor.__all__ + \
'contrib',
'data',
'dygraph',
'enable_dygraph',
'disable_dygraph',
'transpiler',
'nets',
'optimizer',
......
......@@ -13,6 +13,7 @@
# limitations under the License.
from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator
import contextlib
import sys
import numpy as np
from paddle.fluid import core
from paddle.fluid import framework
......@@ -23,6 +24,9 @@ import objgraph
__all__ = [
'no_grad',
'guard',
'enable_dygraph',
'disable_dygraph',
'enabled',
'to_variable',
]
......@@ -49,12 +53,85 @@ def program_desc_tracing_guard(enable):
tracer._enable_program_desc_tracing = original_val
# This function should be removed in V1.6, because it can easily lead to cyclic dependencies.
_functional_dygraph_context_manager = None
def enabled():
# Internal use only
"""
This function checks whether the program runs in dynamic graph mode or not.
You can enter dynamic graph mode with :ref:`api_fluid_dygraph_guard` api,
or enable and disable dynamic graph mode with :ref:`api_fluid_dygraph_enable`
and :ref:`api_fluid_dygraph_disable` api .
**Note**:
``fluid.dygraph.enabled`` is the alias of ``fluid.in_dygraph_mode``, and
``fluid.in_dygraph_mode`` is recommended to use.
Returns:
bool: Whether the program is running in dynamic graph mode.
Examples:
.. code-block:: python
import paddle.fluid as fluid
fluid.enable_dygraph() # Now we are in dygragh mode
print(fluid.dygraph.enabled()) # True
fluid.disable_dygraph()
print(fluid.dygraph.enabled()) # False
"""
return framework.in_dygraph_mode()
def enable_dygraph(place=None):
"""
This function enables dynamic graph mode.
Parameters:
place(fluid.CPUPlace or fluid.CUDAPlace, optional): Place to execute dygraph.
If None, the running place will be determined according to the way of paddle compilation. Default: None
return:
None
Examples:
.. code-block:: python
import paddle.fluid as fluid
fluid.enable_dygraph() # Now we are in dygragh mode
print(fluid.in_dygraph_mode()) # True
fluid.disable_dygraph()
print(fluid.in_dygraph_mode()) # False
"""
global _functional_dygraph_context_manager
_functional_dygraph_context_manager = guard(place=place)
_functional_dygraph_context_manager.__enter__()
def disable_dygraph():
"""
This function disables dynamic graph mode.
return:
None
Examples:
.. code-block:: python
import paddle.fluid as fluid
fluid.enable_dygraph() # Now we are in dygragh mode
print(fluid.in_dygraph_mode()) # True
fluid.disable_dygraph()
print(fluid.in_dygraph_mode()) # False
"""
global _functional_dygraph_context_manager
if _functional_dygraph_context_manager is not None:
_functional_dygraph_context_manager.__exit__(*sys.exc_info())
_functional_dygraph_context_manager = None
@contextlib.contextmanager
def _switch_tracer_mode_guard_(is_train=True):
tracer = framework._dygraph_tracer()
......
......@@ -173,7 +173,9 @@ def require_version(min_version, max_version=None):
def in_dygraph_mode():
"""
This function checks whether the program runs in dynamic graph mode or not.
You can turn on dynamic graph mode with :ref:`api_fluid_dygraph_guard` api.
You can enter dynamic graph mode with :ref:`api_fluid_dygraph_guard` api,
or enable and disable dynamic graph mode with :ref:`api_fluid_dygraph_enable`
and :ref:`api_fluid_dygraph_disable` api .
Returns:
bool: Whether the program is running in dynamic graph mode.
......@@ -182,11 +184,11 @@ def in_dygraph_mode():
.. code-block:: python
import paddle.fluid as fluid
if fluid.in_dygraph_mode():
print('running in dygraph mode')
else:
print('not running in dygraph mode')
fluid.enable_dygraph() # Now we are in dygragh mode
print(fluid.in_dygraph_mode()) # True
fluid.disable_dygraph()
print(fluid.in_dygraph_mode()) # False
"""
return _dygraph_tracer_ is not None
......
......@@ -177,6 +177,31 @@ class SimpleRNN(fluid.Layer):
class TestImperative(unittest.TestCase):
def test_functional_dygraph_context(self):
self.assertFalse(fluid.dygraph.enabled())
fluid.enable_dygraph()
self.assertTrue(fluid.dygraph.enabled())
np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
var_inp = fluid.dygraph.base.to_variable(np_inp)
mlp = MLP(input_size=2)
out = mlp(var_inp)
dy_out1 = out.numpy()
out.backward()
dy_grad1 = mlp._linear1.weight.gradient()
fluid.disable_dygraph()
self.assertFalse(fluid.dygraph.enabled())
with fluid.dygraph.guard():
self.assertTrue(fluid.dygraph.enabled())
var_inp = fluid.dygraph.base.to_variable(np_inp)
mlp = MLP(input_size=2)
out = mlp(var_inp)
dy_out2 = out.numpy()
out.backward()
dy_grad2 = mlp._linear1.weight.gradient()
self.assertFalse(fluid.dygraph.enabled())
self.assertTrue(np.array_equal(dy_out1, dy_out2))
self.assertTrue(np.array_equal(dy_grad1, dy_grad2))
def test_isinstance(self):
var = fluid.layers.data(shape=[1], name='x', dtype='float32')
self.assertTrue(isinstance(var, fluid.Variable))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册