diff --git a/python/paddle/framework/framework.py b/python/paddle/framework/framework.py index 17eaa82cd8b6a049ce99396fe0aaf2ab0476a182..d5fa45f76884f2d32644625e85fa5006963121e5 100644 --- a/python/paddle/framework/framework.py +++ b/python/paddle/framework/framework.py @@ -87,8 +87,6 @@ def get_default_dtype(): @contextmanager def set_grad_enabled(mode): """ - :api_attr: imperative - Create a context which enables or disables dygraph gradient calculation. Args: @@ -96,11 +94,13 @@ def set_grad_enabled(mode): Examples: .. code-block:: python + + import paddle x = paddle.ones([3, 2]) x.stop_gradient = False - with torch.set_grad_enabled(False): + with paddle.set_grad_enabled(False): y = x * 2 - with torch.set_grad_enabled(True): + with paddle.set_grad_enabled(True): z = x * 2 print(y.stop_gradient) # True print(z.stop_gradient) # False