# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools

import paddle
from paddle.fluid.framework import dygraph_only
from paddle.fluid.dygraph.amp.auto_cast import amp_state
from paddle.amp.auto_cast import auto_cast
from paddle.fluid import core

# Deliberately empty: ``from <this module> import *`` exports no names.
__all__ = []


class LegacyPyLayerContext:
    """
    Context object passed as the first argument to a PyLayer's ``forward``
    and ``backward``.

    ``forward`` stores tensors with ``save_for_backward``, and ``backward``
    reads them back with ``saved_tensor``. The context also records the AMP
    (automatic mixed precision) state that was current when it was created.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.autograd import PyLayer

            class cus_tanh(PyLayer):
                @staticmethod
                def forward(ctx, x):
                    # ctx is an object of PyLayerContext.
                    y = paddle.tanh(x)
                    ctx.save_for_backward(y)
                    return y

                @staticmethod
                def backward(ctx, dy):
                    # ctx is an object of PyLayerContext.
                    y, = ctx.saved_tensor()
                    grad = dy * (1 - paddle.square(y))
                    return grad
    """

    def __init__(self):
        # ``saved_tensor`` returns None until ``save_for_backward`` runs.
        self.container = None
        # Snapshot of the AMP state at construction time. The legacy backward
        # context reads this to decide whether to re-enter ``auto_cast``.
        self._amp_state = amp_state()

    def save_for_backward(self, *tensors):
        """
        Store tensors that ``backward`` will need. Retrieve them inside
        ``backward`` with ``saved_tensor``.

        Note:
            Call this API at most once, and only from inside `forward`.

        Args:
            tensors(list of Tensors): the tensors to keep for the backward pass.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                from paddle.autograd import PyLayer

                class cus_tanh(PyLayer):
                    @staticmethod
                    def forward(ctx, x):
                        # ctx keeps objects around for the backward pass.
                        y = paddle.tanh(x)
                        # Hand y over to backward.
                        ctx.save_for_backward(y)
                        return y

                    @staticmethod
                    def backward(ctx, dy):
                        # Fetch the tensors stored by forward.
                        y, = ctx.saved_tensor()
                        grad = dy * (1 - paddle.square(y))
                        return grad

        """
        # Calling again replaces the previously stored tuple.
        self.container = tensors

    def saved_tensor(self):
        """
        Return the tensors previously stored by ``save_for_backward``.

        Returns:
            list of Tensors or None: the stored tensors (as the tuple passed
            to `save_for_backward`), or None if nothing was stored.

        Examples:
            .. code-block:: python

                import paddle
                from paddle.autograd import PyLayer

                class cus_tanh(PyLayer):
                    @staticmethod
                    def forward(ctx, x):
                        # ctx keeps objects around for the backward pass.
                        y = paddle.tanh(x)
                        # Hand y over to backward.
                        ctx.save_for_backward(y)
                        return y

                    @staticmethod
                    def backward(ctx, dy):
                        # Fetch the tensors stored by forward.
                        y, = ctx.saved_tensor()
                        grad = dy * (1 - paddle.square(y))
                        return grad
        """
        return self.container


def with_mateclass(meta, *bases):
    """
    Return a placeholder base class that gives a subclass the metaclass ``meta``.

    Writing ``class Foo(with_mateclass(Meta, A, B))`` creates ``Foo`` by
    calling ``Meta("Foo", (A, B), namespace)``. The placeholder returned here
    does not appear in ``Foo``'s bases.

    Args:
        meta(type): the metaclass used to build the subclass.
        *bases(type): the real base classes of the subclass.

    Returns:
        type: a placeholder class meant to be the sole listed base.
    """

    class _BaseSwappingMeta(meta):
        def __new__(mcs, name, _placeholder_bases, namespace):
            # Build the real class from ``meta`` with the real ``bases``,
            # discarding the placeholder that routed the call here.
            return meta(name, bases, namespace)

    # Create the placeholder through ``type.__new__`` directly, so the
    # base-swapping ``__new__`` above is not triggered for the placeholder.
    return type.__new__(_BaseSwappingMeta, "impl", (), {})


class CPyLayer:
    @classmethod
    @dygraph_only
    def apply(cls, *args, **kwargs):
        """
        Execute the custom PyLayer on the given inputs.

        After defining a PyLayer subclass, run it by calling ``apply`` on the
        class. This method is decorated with ``dygraph_only``, and the call
        into the core runtime is made with gradient recording disabled.

        Args:
            *args(tuple): positional inputs of the PyLayer.
            **kwargs(dict): keyword inputs of the PyLayer.

        Returns:
            tensors or other types : the output of the PyLayer.

        Examples:
            .. code-block:: python

                import paddle
                from paddle.autograd import PyLayer

                class cus_tanh(PyLayer):
                    @staticmethod
                    def forward(ctx, x, func1, func2=paddle.square):
                        ctx.func = func2
                        y = func1(x)
                        # Hand y over to backward.
                        ctx.save_for_backward(y)
                        return y

                    @staticmethod
                    def backward(ctx, dy):
                        # Fetch the tensors stored by forward.
                        y, = ctx.saved_tensor()
                        grad = dy * (1 - ctx.func(y))
                        return grad


                data = paddle.randn([2, 3], dtype="float64")
                data.stop_gradient = False
                # Run the custom layer.
                z = cus_tanh.apply(data, func1=paddle.tanh)
        """
        expected_place = paddle.fluid.framework._current_expected_place()
        with paddle.fluid.dygraph.no_grad():
            outputs = core.pylayer_apply(expected_place, cls, *args, **kwargs)
        return outputs


class PyLayerBackward(LegacyPyLayerContext):
    """
    Context class used for the backward pass of a legacy PyLayer.

    ``LayerMeta`` derives one subclass of this per PyLayer class (named
    ``<name>_backward``) and stores that PyLayer class on it as
    ``_forward_cls``.
    """

    def backward(self, *args, **kwargs):
        """
        Invoke the user-defined ``backward`` of ``_forward_cls``.

        The call runs inside ``paddle.fluid.dygraph.guard()`` with gradient
        recording disabled. If the AMP state captured by this context has a
        truthy ``'enable'`` entry, the call is also wrapped in ``auto_cast``,
        configured from ``args[0]._amp_state``.

        Args:
            *args: forwarded unchanged to ``_forward_cls.backward``.
            **kwargs: forwarded unchanged to ``_forward_cls.backward``.

        Returns:
            Whatever ``_forward_cls.backward`` returns.
        """
        with paddle.fluid.dygraph.guard():
            with paddle.fluid.dygraph.no_grad():
                if (
                    self._amp_state
                    and 'enable' in self._amp_state
                    and self._amp_state['enable']
                ):
                    # NOTE(review): the enable check reads ``self._amp_state``,
                    # but auto_cast is configured from ``args[0]._amp_state``.
                    # Presumably args[0] is this same ctx object -- confirm.
                    with auto_cast(**args[0]._amp_state):
                        return self._forward_cls.backward(*args, **kwargs)
                # AMP was not enabled when this context was created: plain call.
                return self._forward_cls.backward(*args, **kwargs)


class LayerMeta(type):
    """Metaclass that attaches a dedicated backward-context class to each legacy PyLayer."""

    def __init__(cls, name, bases, attrs):
        # Derive ``<name>_backward`` from PyLayerBackward and point it back at
        # this PyLayer class so its ``backward`` can be dispatched.
        backward_cls_name = name + '_backward'
        cls._backward_function = type(
            backward_cls_name, (PyLayerBackward,), {"_forward_cls": cls}
        )
        super().__init__(name, bases, attrs)


class LegacyPyLayer(with_mateclass(LayerMeta, CPyLayer)):
    """
    Base class for custom layers whose forward and backward computations are
    written by the user. Subclasses must follow these rules:

    1. Define both ``forward`` and ``backward`` as ``@staticmethod``. Each takes
       a context object as its first argument, and ``None`` must not appear in
       the returned result.
    2. ``backward`` receives the context first, followed by the gradients of
       ``forward``'s output tensors. It therefore takes as many gradient inputs
       as ``forward`` has output tensors. If ``backward`` needs ``forward``'s
       inputs or outputs, store them with ``save_for_backward``.
    3. ``backward`` returns a ``Tensor`` or a tuple/list of ``Tensor``. These
       are the gradients of ``forward``'s input tensors, so there is one per
       input tensor.

    Once the subclass is defined, run it through its ``apply`` method.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.autograd import PyLayer

            # Inherit from PyLayer
            class cus_tanh(PyLayer):
                @staticmethod
                def forward(ctx, x, func1, func2=paddle.square):
                    # ctx keeps objects around for the backward pass.
                    ctx.func = func2
                    y = func1(x)
                    # Hand y over to backward.
                    ctx.save_for_backward(y)
                    return y

                @staticmethod
                # forward has a single output, so backward receives a single gradient.
                def backward(ctx, dy):
                    # Fetch the tensors stored by forward.
                    y, = ctx.saved_tensor()
                    grad = dy * (1 - ctx.func(y))
                    # forward has a single input, so return a single gradient.
                    return grad


            data = paddle.randn([2, 3], dtype="float64")
            data.stop_gradient = False
            z = cus_tanh.apply(data, func1=paddle.tanh)
            z.mean().backward()

            print(data.grad)

    """

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        Forward computation; subclasses must override it.

        The first argument is an object of `PyLayerContext`. It is followed by
        any number of arguments (tensors or other types). `None` must not
        appear in the returned result.

        Args:
            *args(tuple): positional inputs of the PyLayer.
            **kwargs(dict): keyword inputs of the PyLayer.

        Returns:
            tensors or other types : the output of the PyLayer.

        Examples:
            .. code-block:: python

                import paddle
                from paddle.autograd import PyLayer

                class cus_tanh(PyLayer):
                    @staticmethod
                    def forward(ctx, x):
                        y = paddle.tanh(x)
                        # Hand y over to backward.
                        ctx.save_for_backward(y)
                        return y

                    @staticmethod
                    def backward(ctx, dy):
                        # Fetch the tensors stored by forward.
                        y, = ctx.saved_tensor()
                        grad = dy * (1 - paddle.square(y))
                        return grad
        """
        raise NotImplementedError("You must implement the forward function for PyLayer.")

    @staticmethod
    def backward(ctx, *args, **kwargs):
        """
        Gradient computation; subclasses must override it.

        The first argument is an object of `PyLayerContext`. The remaining
        arguments are the gradients of forward's output tensors. The returned
        tensors are the gradients of forward's input tensors.

        Args:
            *args(tuple): the gradient(s) of forward's output tensor(s).
            **kwargs(dict): the gradient(s) of forward's output tensor(s).

        Returns:
            Tensor or list of Tensors: the gradient(s) of forward's input tensor(s).

        Examples:
            .. code-block:: python

                import paddle
                from paddle.autograd import PyLayer

                class cus_tanh(PyLayer):
                    @staticmethod
                    def forward(ctx, x):
                        y = paddle.tanh(x)
                        # Hand y over to backward.
                        ctx.save_for_backward(y)
                        return y

                    @staticmethod
                    def backward(ctx, dy):
                        # Fetch the tensors stored by forward.
                        y, = ctx.saved_tensor()
                        grad = dy * (1 - paddle.square(y))
                        return grad
        """
        raise NotImplementedError("You must implement the backward function for PyLayer.")


class EagerPyLayerContext:
    """Mixin of context helpers available to eager-mode PyLayer ``forward``/``backward``."""

    def save_for_backward(self, *tensors):
        """
        Store tensors that ``backward`` will need. Retrieve them inside
        ``backward`` with ``saved_tensor``.

        Note:
            Call this API at most once, and only from inside `forward`.

        Args:
            tensors(list of Tensors): the tensors to keep for the backward pass.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                from paddle.autograd import PyLayer

                class cus_tanh(PyLayer):
                    @staticmethod
                    def forward(ctx, x):
                        # ctx keeps objects around for the backward pass.
                        y = paddle.tanh(x)
                        # Hand y over to backward.
                        ctx.save_for_backward(y)
                        return y

                    @staticmethod
                    def backward(ctx, dy):
                        # Fetch the tensors stored by forward.
                        y, = ctx.saved_tensor()
                        grad = dy * (1 - paddle.square(y))
                        return grad

        """
        self.container = tensors

    def saved_tensor(self):
        """
        Return the tensors previously stored by ``save_for_backward``.

        Returns:
            list of Tensors or None: the tensors stored by `save_for_backward`,
            or None if the context holds none.

        Examples:
            .. code-block:: python

                import paddle
                from paddle.autograd import PyLayer

                class cus_tanh(PyLayer):
                    @staticmethod
                    def forward(ctx, x):
                        # ctx keeps objects around for the backward pass.
                        y = paddle.tanh(x)
                        # Hand y over to backward.
                        ctx.save_for_backward(y)
                        return y

                    @staticmethod
                    def backward(ctx, dy):
                        # Fetch the tensors stored by forward.
                        y, = ctx.saved_tensor()
                        grad = dy * (1 - paddle.square(y))
                        return grad
        """
        return self.container

    def mark_not_inplace(self, *args):
        """
        Mark input tensors as not inplace.

        Call this at most once, only from inside the `forward` method, and
        pass only Tensor inputs.

        Suppose a Tensor returned by `forward` is the same object as one of its
        inputs, and that input is marked not_inplace. Paddle then creates a new
        Tensor as the output, so the autograd information of the input Tensor
        is not overwritten.

        Examples:
            .. code-block:: python

                import paddle

                class Exp(paddle.autograd.PyLayer):
                    @staticmethod
                    def forward(ctx, x):
                        ctx.mark_not_inplace(x)
                        return x

                    @staticmethod
                    def backward(ctx, grad_output):
                        out = grad_output.exp()
                        return out

                x = paddle.randn((1, 1))
                x.stop_gradient = False
                attn_layers = []
                for idx in range(0, 2):
                    attn_layers.append(Exp())

                for step in range(0, 2):
                    a = x
                    for j in range(0, 2):
                        a = attn_layers[j].apply(x)
                    a.backward()
        """
        self.not_inplace_tensors = args

    def mark_non_differentiable(self, *args):
        """
        Mark output tensors as non-differentiable.

        Call this at most once, only from inside the `forward` method, and
        pass only tensor outputs.

        Marked outputs do not require gradients, which makes the backward
        computation more efficient. `backward` must still accept a gradient
        for every output. For a marked output, that gradient is always a zero
        tensor with the same shape as the output.

        Examples:
            .. code-block:: python

                import os
                os.environ['FLAGS_enable_eager_mode'] = '1'
                import paddle
                from paddle.autograd import PyLayer
                import numpy as np

                class Tanh(PyLayer):
                    @staticmethod
                    def forward(ctx, x):
                        a = x + x
                        b = x + x + x
                        ctx.mark_non_differentiable(a)
                        return a, b

                    @staticmethod
                    def backward(ctx, grad_a, grad_b):
                        assert np.equal(grad_a.numpy(), paddle.zeros([1]).numpy())
                        assert np.equal(grad_b.numpy(), paddle.ones([1], dtype="float64").numpy())
                        return grad_b

                x = paddle.ones([1], dtype="float64")
                x.stop_gradient = False
                a, b = Tanh.apply(x)
                b.sum().backward()
        """
        self.non_differentiable = args

    def set_materialize_grads(self, value: bool):
        """
        Set whether output gradient tensors are materialized. The default is True.

        Call this only from inside the `forward` method.

        When True, undefined output gradients are expanded into zero-filled
        tensors before `backward` is called. When False, they are passed to
        `backward` as None.

        Examples:
            .. code-block:: python

                import os
                os.environ['FLAGS_enable_eager_mode'] = '1'
                import paddle
                from paddle.autograd import PyLayer
                import numpy as np

                class Tanh(PyLayer):
                    @staticmethod
                    def forward(ctx, x):
                        return x+x+x, x+x

                    @staticmethod
                    def backward(ctx, grad, grad2):
                        assert np.equal(grad2.numpy(), paddle.zeros([1]).numpy())
                        return grad

                class Tanh2(PyLayer):
                    @staticmethod
                    def forward(ctx, x):
                        ctx.set_materialize_grads(False)
                        return x+x+x, x+x

                    @staticmethod
                    def backward(ctx, grad, grad2):
                        assert grad2 is None
                        return grad

                x = paddle.ones([1], dtype="float64")
                x.stop_gradient = False
                Tanh.apply(x)[0].backward()

                x2 = paddle.ones([1], dtype="float64")
                x2.stop_gradient = False
                Tanh2.apply(x2)[0].backward()
        """
        self.materialize_grads = value


class EagerPyLayerBackward(core.eager.PyLayer, EagerPyLayerContext):
    """
    Backward object for eager-mode PyLayers.

    ``EagerPyLayerMeta`` derives one subclass of this per PyLayer class and
    sets ``_forward_cls`` to that PyLayer class.
    """

    def backward(self, *args):
        # Dispatch to the user's static ``backward``, passing this object as ctx.
        user_backward = self._forward_cls.backward
        return user_backward(self, *args)


class EagerPyLayerMeta(type):
    """Metaclass that gives every eager PyLayer its own backward class."""

    def __init__(cls, name, bases, attrs):
        # Derive ``<name>_backward`` from EagerPyLayerBackward and point it back
        # at this PyLayer class so its ``backward`` can be dispatched.
        backward_cls_name = name + '_backward'
        cls._backward_function = type(
            backward_cls_name, (EagerPyLayerBackward,), {"_forward_cls": cls}
        )
        super().__init__(name, bases, attrs)


class EagerPyLayer(
    with_mateclass(EagerPyLayerMeta, core.eager.PyLayer, EagerPyLayerContext)
):
    """Eager-mode base class for PyLayers with user-defined forward/backward."""

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        Forward computation; subclasses must override it.

        The first argument is an object of `PyLayerContext`. It is followed by
        any number of arguments (tensors or other types). `None` must not
        appear in the returned result.

        Args:
            *args(tuple): positional inputs of the PyLayer.
            **kwargs(dict): keyword inputs of the PyLayer.

        Returns:
            tensors or other types : the output of the PyLayer.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.

        Examples:
            .. code-block:: python

                import paddle
                from paddle.autograd import PyLayer

                class cus_tanh(PyLayer):
                    @staticmethod
                    def forward(ctx, x):
                        y = paddle.tanh(x)
                        # Pass tensors to backward.
                        ctx.save_for_backward(y)
                        return y

                    @staticmethod
                    def backward(ctx, dy):
                        # Get the tensors passed by forward.
                        y, = ctx.saved_tensor()
                        grad = dy * (1 - paddle.square(y))
                        return grad
        """
        raise NotImplementedError(
            "You must implement the forward function for PyLayer."
        )

    @staticmethod
    def backward(ctx, *args):
        """
        Gradient computation; subclasses must override it.

        The first argument is an object of `PyLayerContext`. The remaining
        positional arguments are the gradients of forward's output tensors.
        The returned tensors are the gradients of forward's input tensors.
        Keyword arguments are not accepted.

        Args:
            *args(tuple): the gradient(s) of forward's output tensor(s).

        Returns:
            Tensor or list of Tensors: the gradient(s) of forward's input tensor(s).

        Raises:
            NotImplementedError: always, unless overridden by a subclass.

        Examples:
            .. code-block:: python

                import paddle
                from paddle.autograd import PyLayer

                class cus_tanh(PyLayer):
                    @staticmethod
                    def forward(ctx, x):
                        y = paddle.tanh(x)
                        # Pass tensors to backward.
                        ctx.save_for_backward(y)
                        return y

                    @staticmethod
                    def backward(ctx, dy):
                        # Get the tensors passed by forward.
                        y, = ctx.saved_tensor()
                        grad = dy * (1 - paddle.square(y))
                        return grad
        """

        raise NotImplementedError(
            "You must implement the backward function for PyLayer."
        )


def once_differentiable(backward):
    """
    Decorate a PyLayer ``backward`` so that it runs with gradient recording disabled.

    The wrapper preserves the decorated function's ``__name__``, ``__doc__``
    and other metadata (via ``functools.wraps``), and exposes it as
    ``__wrapped__``.

    Args:
        backward(callable): the backward function, called as
            ``backward(ctx, *args, **kwargs)``.

    Returns:
        callable: the wrapped backward function.
    """

    @functools.wraps(backward)
    def wrapper(ctx, *args, **kwargs):
        # Operations inside backward are not recorded for a further backward pass.
        with paddle.fluid.dygraph.no_grad():
            outputs = backward(ctx, *args, **kwargs)
        return outputs

    return wrapper