未验证 提交 9fe93b15 编写于 作者: H hong 提交者: GitHub

add enable_imperative, disable_imperative (#24387)

* add enable_imperative, disable_imperative; test=develop

* add unitest; test=develop

* polish example; test=develop
上级 d20c88c5
...@@ -84,7 +84,7 @@ from paddle.fluid.layers.math_op_patch import monkey_patch_variable ...@@ -84,7 +84,7 @@ from paddle.fluid.layers.math_op_patch import monkey_patch_variable
from . import install_check from . import install_check
from .dygraph.nn import * from .dygraph.nn import *
from .dygraph.layers import * from .dygraph.layers import *
from .dygraph.base import enable_dygraph, disable_dygraph from .dygraph.base import enable_dygraph, disable_dygraph, enable_imperative, disable_imperative
from .io import save, load, load_program_state, set_program_state from .io import save, load, load_program_state, set_program_state
from .dygraph.checkpoint import save_dygraph, load_dygraph from .dygraph.checkpoint import save_dygraph, load_dygraph
from .dygraph.varbase_patch_methods import monkey_patch_varbase from .dygraph.varbase_patch_methods import monkey_patch_varbase
...@@ -104,6 +104,8 @@ __all__ = framework.__all__ + executor.__all__ + \ ...@@ -104,6 +104,8 @@ __all__ = framework.__all__ + executor.__all__ + \
'dygraph', 'dygraph',
'enable_dygraph', 'enable_dygraph',
'disable_dygraph', 'disable_dygraph',
'enable_imperative',
'disable_imperative',
'transpiler', 'transpiler',
'nets', 'nets',
'optimizer', 'optimizer',
......
...@@ -29,6 +29,8 @@ __all__ = [ ...@@ -29,6 +29,8 @@ __all__ = [
'guard', 'guard',
'enable_dygraph', 'enable_dygraph',
'disable_dygraph', 'disable_dygraph',
'enable_imperative',
'disable_imperative',
'enabled', 'enabled',
'to_variable', 'to_variable',
] ]
...@@ -86,6 +88,54 @@ def enabled(): ...@@ -86,6 +88,54 @@ def enabled():
return framework.in_dygraph_mode() return framework.in_dygraph_mode()
def enable_imperative(place=None):
"""
This function enables imperative mode.
Parameters:
place(fluid.CPUPlace or fluid.CUDAPlace, optional): Place for imperative execution.
If None, the running place will be determined according to the way of paddle compilation. Default: None
return:
None
Examples:
.. code-block:: python
import paddle.fluid as fluid
fluid.enable_imperative() # Now we are in imperative mode
x = fluid.layers.ones( (2, 2), "float32")
y = fluid.layers.zeros( (2, 2), "float32")
z = x + y
print( z.numpy() ) #[[1, 1], [1, 1]]
"""
enable_dygraph(place)
def disable_imperative():
"""
This function disables imperative mode.
return:
None
Examples:
.. code-block:: python
import paddle.fluid as fluid
fluid.enable_imperative() # Now we are in imperative mode
x = fluid.layers.ones( (2, 2), "float32")
y = fluid.layers.zeros( (2, 2), "float32")
z = x + y
print( z.numpy() ) #[[1, 1], [1, 1]]
fluid.disable_imperative() # Now we are in declarative mode
"""
disable_dygraph()
def enable_dygraph(place=None): def enable_dygraph(place=None):
""" """
This function enables dynamic graph mode. This function enables dynamic graph mode.
......
...@@ -203,6 +203,20 @@ class TestImperative(unittest.TestCase): ...@@ -203,6 +203,20 @@ class TestImperative(unittest.TestCase):
self.assertTrue(np.array_equal(dy_out1, dy_out2)) self.assertTrue(np.array_equal(dy_out1, dy_out2))
self.assertTrue(np.array_equal(dy_grad1, dy_grad2)) self.assertTrue(np.array_equal(dy_grad1, dy_grad2))
def test_anable_imperative(self):
self.assertFalse(fluid.dygraph.enabled())
fluid.enable_imperative()
self.assertTrue(fluid.dygraph.enabled())
np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
var_inp = fluid.dygraph.base.to_variable(np_inp)
mlp = MLP(input_size=2)
out = mlp(var_inp)
dy_out1 = out.numpy()
out.backward()
dy_grad1 = mlp._linear1.weight.gradient()
fluid.disable_imperative()
self.assertFalse(fluid.dygraph.enabled())
def test_isinstance(self): def test_isinstance(self):
var = fluid.layers.data(shape=[1], name='x', dtype='float32') var = fluid.layers.data(shape=[1], name='x', dtype='float32')
self.assertTrue(isinstance(var, fluid.Variable)) self.assertTrue(isinstance(var, fluid.Variable))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册