diff --git a/python/paddle/autograd/py_layer.py b/python/paddle/autograd/py_layer.py
index 08c4fb9ac8d5ddb38ef1b8dc0907f4b76b95a133..d2dd31f08dcac41355996d39334129e5902a15f2 100644
--- a/python/paddle/autograd/py_layer.py
+++ b/python/paddle/autograd/py_layer.py
@@ -29,6 +29,27 @@ def with_mateclass(meta, *bases):
 class PyLayerContext:
     """
     ``PyLayerContext`` can assist the :ref:`api_paddle_autograd_PyLayer` in implementing certain functionalities.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            from paddle.autograd import PyLayer
+
+            class cus_tanh(PyLayer):
+                @staticmethod
+                def forward(ctx, x):
+                    # ctx is an object of PyLayerContext.
+                    y = paddle.tanh(x)
+                    ctx.save_for_backward(y)
+                    return y
+
+                @staticmethod
+                def backward(ctx, dy):
+                    # ctx is an object of PyLayerContext.
+                    y, = ctx.saved_tensor()
+                    grad = dy * (1 - paddle.square(y))
+                    return grad
     """

     def save_for_backward(self, *tensors):
@@ -266,6 +287,33 @@ class PyLayer(with_mateclass(PyLayerMeta, core.eager.PyLayer, PyLayerContext)):

         After building the custom operator, apply it by running the ``apply`` method.

+    Examples:
+        .. code-block:: python
+
+            import paddle
+            from paddle.autograd import PyLayer
+
+            class cus_tanh(PyLayer):
+                @staticmethod
+                def forward(ctx, x):
+                    y = paddle.tanh(x)
+                    # Pass tensors to backward.
+                    ctx.save_for_backward(y)
+                    return y
+
+                @staticmethod
+                def backward(ctx, dy):
+                    # Get the tensors passed by forward.
+                    y, = ctx.saved_tensor()
+                    grad = dy * (1 - paddle.square(y))
+                    return grad
+
+            data = paddle.randn([2, 3], dtype="float64")
+            data.stop_gradient = False
+            z = cus_tanh.apply(data)
+            z.mean().backward()
+
+            print(data.grad)
     """

     @staticmethod
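
Beyond the patch itself, here is a minimal sketch (not part of the diff, assuming the ``cus_tanh`` class from the docstring examples above) that cross-checks the hand-written ``backward`` against the gradient Paddle computes for the built-in ``tanh``:

.. code-block:: python

    import paddle
    from paddle.autograd import PyLayer

    class cus_tanh(PyLayer):
        @staticmethod
        def forward(ctx, x):
            y = paddle.tanh(x)
            ctx.save_for_backward(y)
            return y

        @staticmethod
        def backward(ctx, dy):
            y, = ctx.saved_tensor()
            # d/dx tanh(x) = 1 - tanh(x)^2
            return dy * (1 - paddle.square(y))

    x = paddle.randn([2, 3], dtype="float64")
    x.stop_gradient = False

    # Gradient produced by the custom PyLayer.
    cus_tanh.apply(x).mean().backward()
    custom_grad = x.grad.clone()

    # Gradient produced by the built-in op on the same input.
    x.clear_grad()
    paddle.tanh(x).mean().backward()

    # The hand-written backward should reproduce the built-in autograd.
    assert paddle.allclose(x.grad, custom_grad)

Note that the layer must be run through ``apply``, as the docstring says; calling ``forward`` directly would bypass the autograd bookkeeping that makes ``backward`` fire.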