From d6e259bbe99f82be49685197f7d12fd57132daee Mon Sep 17 00:00:00 2001 From: gouzil <66515297+gouzil@users.noreply.github.com> Date: Thu, 6 Jul 2023 20:04:52 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9COPY-FROM=20No.=203=20autogra?= =?UTF-8?q?d=20(#54921)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [autograd] add copy-from; test=document_fix * [autograd] add copy-from; test=document_fix * fix --- python/paddle/autograd/py_layer.py | 48 ++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/python/paddle/autograd/py_layer.py b/python/paddle/autograd/py_layer.py index 08c4fb9ac8d..d2dd31f08dc 100644 --- a/python/paddle/autograd/py_layer.py +++ b/python/paddle/autograd/py_layer.py @@ -29,6 +29,27 @@ def with_mateclass(meta, *bases): class PyLayerContext: """ ``PyLayerContext`` can assist the :ref:`api_paddle_autograd_PyLayer` in implementing certain functionalities. + + Examples: + .. code-block:: python + + import paddle + from paddle.autograd import PyLayer + + class cus_tanh(PyLayer): + @staticmethod + def forward(ctx, x): + # ctx is a object of PyLayerContext. + y = paddle.tanh(x) + ctx.save_for_backward(y) + return y + + @staticmethod + def backward(ctx, dy): + # ctx is a object of PyLayerContext. + y, = ctx.saved_tensor() + grad = dy * (1 - paddle.square(y)) + return grad """ def save_for_backward(self, *tensors): @@ -266,6 +287,33 @@ class PyLayer(with_mateclass(PyLayerMeta, core.eager.PyLayer, PyLayerContext)): After building the custom operator, apply it by running the ``apply`` method. + Examples: + .. code-block:: python + + import paddle + from paddle.autograd import PyLayer + + class cus_tanh(PyLayer): + @staticmethod + def forward(ctx, x): + y = paddle.tanh(x) + # Pass tensors to backward. + ctx.save_for_backward(y) + return y + + @staticmethod + def backward(ctx, dy): + # Get the tensors passed by forward. + y, = ctx.saved_tensor() + grad = dy * (1 - paddle.square(y)) + return grad + + data = paddle.randn([2, 3], dtype="float64") + data.stop_gradient = False + z = cus_tanh.apply(data) + z.mean().backward() + + print(data.grad) """ @staticmethod -- GitLab