# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.fluid import core

__all__ = []


def with_mateclass(meta, *bases):
    # Helper in the style of ``six.with_metaclass``: classes that inherit from the
    # returned temporary class are actually created by ``meta`` with ``bases`` as
    # their real bases.
    class impl(meta):
        def __new__(cls, name, temp_bases, attrs):
            return meta(name, bases, attrs)

    return type.__new__(impl, "impl", (), {})


class PyLayerContext:
    """
    ``PyLayerContext`` can assist the :ref:`api_paddle_autograd_PyLayer` in implementing certain functionalities.
    """

    def save_for_backward(self, *tensors):
        """
        Saves given tensors that the backward needs. Use ``saved_tensor`` in the `backward` to get the saved tensors.

        Note:
            This API should be called at most once, and only inside `forward`.

        Args:
            tensors(list of Tensors): Tensors to be stored.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                from paddle.autograd import PyLayer

                class cus_tanh(PyLayer):
                    @staticmethod
                    def forward(ctx, x):
                        # ctx is a context object that stores some objects for backward.
                        y = paddle.tanh(x)
                        # Pass tensors to backward.
                        ctx.save_for_backward(y)
                        return y

                    @staticmethod
                    def backward(ctx, dy):
                        # Get the tensors passed by forward.
                        y, = ctx.saved_tensor()
                        grad = dy * (1 - paddle.square(y))
                        return grad

        """
        self.container = tensors

    def saved_tensor(self):
        """
        Get the tensors stored by ``save_for_backward``.

        Returns:
            list of Tensors or None: If the context contains tensors stored by `save_for_backward`,
            then return these tensors; otherwise return None.

        Examples:
            .. code-block:: python

                import paddle
                from paddle.autograd import PyLayer

                class cus_tanh(PyLayer):
                    @staticmethod
                    def forward(ctx, x):
                        # ctx is a context object that stores some objects for backward.
                        y = paddle.tanh(x)
                        # Pass tensors to backward.
                        ctx.save_for_backward(y)
                        return y

                    @staticmethod
                    def backward(ctx, dy):
                        # Get the tensors passed by forward.
                        y, = ctx.saved_tensor()
                        grad = dy * (1 - paddle.square(y))
                        return grad
        """
        return self.container

    def mark_not_inplace(self, *args):
        """
        Marks inputs as not inplace.
        This should be called at most once, only from inside the `forward` method,
        and all arguments should be Tensor inputs.

        If the Tensor returned by the `forward` method is the same Tensor as one of its inputs,
        and this Tensor is marked as not_inplace, then Paddle will create a new Tensor as the output,
        preventing the autograd information of the input Tensor from being overwritten.

        Examples:
            .. code-block:: python

                import paddle

                class Exp(paddle.autograd.PyLayer):
                    @staticmethod
                    def forward(ctx, x):
                        ctx.mark_not_inplace(x)
                        return x

                    @staticmethod
                    def backward(ctx, grad_output):
                        out = grad_output.exp()
                        return out

                x = paddle.randn((1, 1))
                x.stop_gradient = False
                attn_layers = []
                for idx in range(0, 2):
                    attn_layers.append(Exp())

                for step in range(0, 2):
                    a = x
                    for j in range(0, 2):
                        a = attn_layers[j].apply(x)
                    a.backward()
        """
        self.not_inplace_tensors = args

    def mark_non_differentiable(self, *args):
        """
        Marks outputs as non-differentiable.
        This should be called at most once, only from inside the `forward` method,
        and all arguments should be tensor outputs.

        This will mark outputs as not requiring gradients, increasing the
        efficiency of backward computation. You still need to accept a gradient
        for each output in `backward`, but it will always be a zero tensor
        with the same shape as the corresponding output.

        Examples:
            .. code-block:: python

                import paddle
                from paddle.autograd import PyLayer
                import numpy as np

                class Tanh(PyLayer):
                    @staticmethod
                    def forward(ctx, x):
                        a = x + x
                        b = x + x + x
                        ctx.mark_non_differentiable(a)
                        return a, b

                    @staticmethod
                    def backward(ctx, grad_a, grad_b):
                        assert np.equal(grad_a.numpy(), paddle.zeros([1]).numpy())
                        assert np.equal(grad_b.numpy(), paddle.ones([1], dtype="float64").numpy())
                        return grad_b

                x = paddle.ones([1], dtype="float64")
                x.stop_gradient = False
                a, b = Tanh.apply(x)
                b.sum().backward()
        """
        self.non_differentiable = args

    def set_materialize_grads(self, value: bool):
        """
        Sets whether to materialize output grad tensors. Default is True.

        This should be called only from inside the `forward` method.

        If True, undefined output grad tensors will be expanded to tensors full
        of zeros prior to calling the `backward` method.

        If False, undefined output grad tensors will be None.

        Examples:
            .. code-block:: python

                import paddle
                from paddle.autograd import PyLayer
                import numpy as np

                class Tanh(PyLayer):
                    @staticmethod
                    def forward(ctx, x):
                        return x+x+x, x+x

                    @staticmethod
                    def backward(ctx, grad, grad2):
                        assert np.equal(grad2.numpy(), paddle.zeros([1]).numpy())
                        return grad

                class Tanh2(PyLayer):
                    @staticmethod
                    def forward(ctx, x):
                        ctx.set_materialize_grads(False)
                        return x+x+x, x+x

                    @staticmethod
                    def backward(ctx, grad, grad2):
                        assert grad2 is None
                        return grad

                x = paddle.ones([1], dtype="float64")
                x.stop_gradient = False
                Tanh.apply(x)[0].backward()

                x2 = paddle.ones([1], dtype="float64")
                x2.stop_gradient = False
                Tanh2.apply(x2)[0].backward()
        """
        self.materialize_grads = value


class PyLayerBackward(core.eager.PyLayer, PyLayerContext):
    def backward(self, *args):
        return self._forward_cls.backward(self, *args)


class PyLayerMeta(type):
    def __init__(cls, name, bases, attrs):
        # For every PyLayer subclass, create a companion ``<name>_backward`` class
        # derived from PyLayerBackward that keeps a reference to its forward class.
        cls._backward_function = type(
            name + '_backward', (PyLayerBackward,), {"_forward_cls": cls}
        )

        return super().__init__(name, bases, attrs)


class PyLayer(with_mateclass(PyLayerMeta, core.eager.PyLayer, PyLayerContext)):
    """
    You can implement a custom Python operator on the PaddlePaddle framework by creating a subclass of
    ``PyLayer``, which must comply with the following rules:

    1. The subclass must contain static ``forward`` and ``backward`` functions, with the first argument being
    :ref:`api_paddle_autograd_PyLayerContext`. If a returned value in ``backward`` corresponds to a ``Tensor`` that
    requires gradients in ``forward``, the returned value must be a ``Tensor``.

    2. Except for the first argument, other arguments of ``backward`` are gradients of the output ``Tensors``
    of ``forward``. Therefore, the number of input ``Tensor`` in ``backward`` must be the same as the number
    of output ``Tensor`` in ``forward``. If you need to use input ``Tensor`` from ``forward`` in ``backward``,
    you can save these ``Tensors`` by inputting them into :ref:`api_paddle_autograd_PyLayerContext`'s
    ``save_for_backward`` method and use them in ``backward`` later.

    3. The output of ``backward`` can be ``Tensor`` or ``list/tuple(Tensor)``, which are gradients of the
    output ``Tensor`` of ``forward``. Therefore, the number of output ``Tensor`` in ``backward`` is the same
    as the number of input ``Tensor`` in ``forward``.

    After building the custom operator, apply it by running the ``apply`` method.

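    Examples:
        A minimal sketch following the three rules above (a hypothetical ``cus_tanh``
        operator, mirroring the examples in the method docstrings of this module):

        .. code-block:: python

            import paddle
            from paddle.autograd import PyLayer

            class cus_tanh(PyLayer):
                @staticmethod
                def forward(ctx, x):
                    y = paddle.tanh(x)
                    # Save the output so backward can reuse it.
                    ctx.save_for_backward(y)
                    return y

                @staticmethod
                def backward(ctx, dy):
                    y, = ctx.saved_tensor()
                    return dy * (1 - paddle.square(y))

            data = paddle.randn([2, 3], dtype="float64")
            data.stop_gradient = False
            # Run the custom operator through its ``apply`` method.
            z = cus_tanh.apply(data)
            z.mean().backward()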
    """

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        It is to be overloaded by subclasses. It must accept an object of :ref:`api_paddle_autograd_PyLayerContext` as
        the first argument, followed by any number of arguments (tensors or other types).
        `None` cannot be included in the returned result.

        Args:
            *args(tuple): input of PyLayer.
            **kwargs(dict): input of PyLayer.

        Returns:
            tensors or other types: output of PyLayer.

        Examples:
            .. code-block:: python

                import paddle
                from paddle.autograd import PyLayer

                class cus_tanh(PyLayer):
                    @staticmethod
                    def forward(ctx, x):
                        y = paddle.tanh(x)
                        # Pass tensors to backward.
                        ctx.save_for_backward(y)
                        return y

                    @staticmethod
                    def backward(ctx, dy):
                        # Get the tensors passed by forward.
                        y, = ctx.saved_tensor()
                        grad = dy * (1 - paddle.square(y))
                        return grad
        """
        raise NotImplementedError(
            "You must implement the forward function for PyLayer."
        )

    @staticmethod
    def backward(ctx, *args):
        """
        This is a function to calculate the gradient. It is to be overloaded by subclasses.
        It must accept an object of :ref:`api_paddle_autograd_PyLayerContext` as the first
        argument, and the remaining arguments are the gradients of forward's output tensors.
        The output tensors of backward are the gradients of forward's input tensors.

        Args:
            *args(tuple): The gradient of forward's output tensor(s).

        Returns:
            Tensor or list of Tensors: The gradient of forward's input tensor(s).

        Examples:
            .. code-block:: python

                import paddle
                from paddle.autograd import PyLayer

                class cus_tanh(PyLayer):
                    @staticmethod
                    def forward(ctx, x):
                        y = paddle.tanh(x)
                        # Pass tensors to backward.
                        ctx.save_for_backward(y)
                        return y

                    @staticmethod
                    def backward(ctx, dy):
                        # Get the tensors passed by forward.
                        y, = ctx.saved_tensor()
                        grad = dy * (1 - paddle.square(y))
                        return grad
        """

        raise NotImplementedError(
            "You must implement the backward function for PyLayer."
        )


def once_differentiable(backward):
    """
    Runs the wrapped ``backward`` under ``no_grad`` so that the computation inside
    it is not recorded; the gradients it produces therefore cannot be differentiated
    again (no higher-order gradients through this ``backward``).
    """

    def wrapper(ctx, *args):
        with paddle.fluid.dygraph.no_grad():
            outputs = backward(ctx, *args)
        return outputs

    return wrapper
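

# A hypothetical usage sketch (not part of the original module): ``once_differentiable``
# is meant to decorate a PyLayer's ``backward`` when gradients of gradients through it
# are not needed, for example:
#
#     class cus_exp(PyLayer):
#         @staticmethod
#         def forward(ctx, x):
#             y = paddle.exp(x)
#             ctx.save_for_backward(y)
#             return y
#
#         @staticmethod
#         @once_differentiable
#         def backward(ctx, dy):
#             y, = ctx.saved_tensor()
#             # d/dx exp(x) = exp(x), and this computation itself is not recorded.
#             return dy * y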