py_layer.py 11.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.fluid.framework import dygraph_only
from paddle.fluid import core
__all__ = ['PyLayer', 'PyLayerContext']


class PyLayerContext(object):
    """
    Context object passed as the first argument (``ctx``) to the ``forward``
    and ``backward`` functions of a ``PyLayer``. It lets ``forward`` hand
    objects over to ``backward``. Tensors are stored with
    ``save_for_backward`` and read back with ``saved_tensor``.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.autograd import PyLayer

            class cus_tanh(PyLayer):
                @staticmethod
                def forward(ctx, x):
                    # ctx is a PyLayerContext object.
                    y = paddle.tanh(x)
                    ctx.save_for_backward(y)
                    return y

                @staticmethod
                def backward(ctx, dy):
                    # ctx is a PyLayerContext object.
                    y, = ctx.saved_tensor()
                    grad = dy * (1 - paddle.square(y))
                    return grad
    """

    def __init__(self):
        # Nothing has been saved yet; ``saved_tensor`` reports this as None.
        self.container = None

    def save_for_backward(self, *tensors):
        """
        Store the tensors that ``backward`` will need. Inside ``backward``,
        fetch them again with ``saved_tensor``.

        .. note::
            Call this API at most once, and only inside ``forward``. Every call
            overwrites whatever an earlier call stored.

        Args:
            *tensors(Tensor): Tensors to be stored, in the order they should be
                returned by ``saved_tensor``.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                from paddle.autograd import PyLayer

                class cus_tanh(PyLayer):
                    @staticmethod
                    def forward(ctx, x):
                        # ctx is a context object that stores objects for backward.
                        y = paddle.tanh(x)
                        # Pass tensors to backward.
                        ctx.save_for_backward(y)
                        return y

                    @staticmethod
                    def backward(ctx, dy):
                        # Get the tensors passed by forward.
                        y, = ctx.saved_tensor()
                        grad = dy * (1 - paddle.square(y))
                        return grad

        """
        # ``*tensors`` already arrives packed as a tuple; keep it unchanged.
        self.container = tensors

    def saved_tensor(self):
        """
        Return the tensors stored by ``save_for_backward``.

        Returns:
            tuple of Tensors or None: the tensors given to the most recent
            ``save_for_backward`` call, in the same order, or ``None`` if
            ``save_for_backward`` has never been called on this context.

        Examples:
            .. code-block:: python

                import paddle
                from paddle.autograd import PyLayer

                class cus_tanh(PyLayer):
                    @staticmethod
                    def forward(ctx, x):
                        # ctx is a context object that stores objects for backward.
                        y = paddle.tanh(x)
                        # Pass tensors to backward.
                        ctx.save_for_backward(y)
                        return y

                    @staticmethod
                    def backward(ctx, dy):
                        # Get the tensors passed by forward.
                        y, = ctx.saved_tensor()
                        grad = dy * (1 - paddle.square(y))
                        return grad
        """
        stored = self.container
        return stored


def with_mateclass(meta, *bases):
    """
    Return a placeholder base class that makes ``meta`` the metaclass of any
    class that inherits from it.

    Writing ``class Foo(with_mateclass(Meta, Base1, Base2)): ...`` builds
    ``Foo`` as ``Meta("Foo", (Base1, Base2), namespace)``. The placeholder is
    dropped, so it does not appear in ``Foo.__bases__``. This presumably exists
    as a Python 2/3-neutral way to declare a metaclass (cf.
    ``six.with_metaclass``). The misspelled public name is kept for
    compatibility.

    Args:
        meta(type): Metaclass for the class being defined.
        *bases(type): Real base classes of the class being defined.

    Returns:
        type: A throwaway class that is only meant to be subclassed once.
    """

    class _ProxyMeta(meta):
        def __new__(mcs, name, _placeholder_bases, namespace):
            # Ignore the placeholder base that triggered this call and build
            # the real class with ``meta`` and the requested ``bases``. The
            # result is an instance of ``meta``, not of this proxy, so the
            # proxy's ``__init__`` is not invoked on it afterwards.
            return meta(name, bases, namespace)

    # Calling ``type.__new__`` directly creates the placeholder without running
    # ``_ProxyMeta.__new__`` or ``meta.__init__`` for it.
    return type.__new__(_ProxyMeta, "impl", (), {})


class CPyLayer(object):
    """
    Base class that provides ``apply``, the entry point used to run a custom
    PyLayer.
    """

    @classmethod
    @dygraph_only
    def apply(cls, *args, **kwargs):
        """
        Run a custom PyLayer once it has been defined.

        Args:
            *args(tuple): positional inputs of the PyLayer, passed on to its
                ``forward``.
            **kwargs(dict): keyword inputs of the PyLayer, passed on to its
                ``forward``.

        Returns:
            tensors or other types: output of the PyLayer.

        Examples:
            .. code-block:: python

                import paddle
                from paddle.autograd import PyLayer

                class cus_tanh(PyLayer):
                    @staticmethod
                    def forward(ctx, x, func1, func2=paddle.square):
                        ctx.func = func2
                        y = func1(x)
                        # Pass tensors to backward.
                        ctx.save_for_backward(y)
                        return y

                    @staticmethod
                    def backward(ctx, dy):
                        # Get the tensors passed by forward.
                        y, = ctx.saved_tensor()
                        grad = dy * (1 - ctx.func(y))
                        return grad


                data = paddle.randn([2, 3], dtype="float64")
                data.stop_gradient = False
                # run custom Layer.
                z = cus_tanh.apply(data, func1=paddle.tanh)
        """
        target_place = paddle.fluid.framework._current_expected_place()
        # The C++ core runs the layer with gradient recording disabled here.
        # NOTE(review): presumably the core attaches the backward node itself,
        # so ops executed inside ``forward`` must not be recorded. Confirm.
        with paddle.fluid.dygraph.no_grad():
            outputs = core.pylayer_apply(target_place, cls, *args, **kwargs)
        return outputs


class PyLayerBackward(PyLayerContext):
    """
    Context type whose ``backward`` dispatches to a PyLayer's ``backward``.

    ``_forward_cls`` is not defined on this class. It is expected as a class
    attribute of subclasses, and it must refer to the PyLayer class whose
    static ``backward`` should be called.
    """

    def backward(self, *args, **kwargs):
        # NOTE(review): ``self`` is not forwarded here as ``ctx``. Presumably
        # the caller (the C++ core) already supplies the context inside
        # ``args``. Confirm against the core implementation.
        with paddle.fluid.dygraph.no_grad():
            user_backward = self._forward_cls.backward
            grads = user_backward(*args, **kwargs)
        return grads


class LayerMeta(type):
    """
    Metaclass that attaches a dedicated backward-context class to every class
    it creates.

    For a class named ``Foo``, it sets ``Foo._backward_function`` to a new
    subclass of ``PyLayerBackward`` named ``Foo_backward``. That subclass's
    ``_forward_cls`` attribute points back to ``Foo``.
    """

    def __init__(cls, name, bases, attrs):
        # Give each class its own context type, so the context's ``backward``
        # can reach this class's user-defined ``backward`` through
        # ``_forward_cls``.
        # NOTE(review): presumably ``core.pylayer_apply`` instantiates
        # ``cls._backward_function`` to create the ``ctx`` object. Confirm.
        cls._backward_function = type(name + '_backward', (PyLayerBackward, ),
                                      {"_forward_cls": cls})

        # ``__init__`` must not return a value; just delegate to ``type``.
        super(LayerMeta, cls).__init__(name, bases, attrs)


class PyLayer(with_mateclass(LayerMeta, CPyLayer)):
    """
    Build a custom `Layer` by subclassing ``PyLayer``. Subclasses must follow
    these rules:

    1. A subclass defines ``forward`` and ``backward``, and both must be
       ``@staticmethod``. The first argument of each is a context object, and
       ``None`` must not be included in the returned result.
    2. ``backward`` takes the context as its first argument. The remaining
       arguments are the gradients of ``forward``'s output tensors, so the
       number of gradient inputs of ``backward`` equals the number of output
       tensors of ``forward``. If ``backward`` needs ``forward``'s inputs or
       outputs, store them with ``save_for_backward`` inside ``forward`` and
       read them back inside ``backward``.
    3. ``backward`` may only return a ``Tensor`` or a tuple/list of ``Tensor``.
       These are the gradients of ``forward``'s input tensors, so the number
       of tensors returned by ``backward`` equals the number of input tensors
       of ``forward``.

    After building the custom Layer, run it through the ``apply`` method.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.autograd import PyLayer

            # Inherit from PyLayer
            class cus_tanh(PyLayer):
                @staticmethod
                def forward(ctx, x, func1, func2=paddle.square):
                    # ctx is a context object that stores objects for backward.
                    ctx.func = func2
                    y = func1(x)
                    # Pass tensors to backward.
                    ctx.save_for_backward(y)
                    return y

                @staticmethod
                # forward has only one output, so there is only one gradient in the input of backward.
                def backward(ctx, dy):
                    # Get the tensors passed by forward.
                    y, = ctx.saved_tensor()
                    grad = dy * (1 - ctx.func(y))
                    # forward has only one input, so only one gradient tensor is returned.
                    return grad


            data = paddle.randn([2, 3], dtype="float64")
            data.stop_gradient = False
            z = cus_tanh.apply(data, func1=paddle.tanh)
            z.mean().backward()

            print(data.grad)

    """

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        Compute the outputs of the custom layer. Subclasses must override this.

        The first argument is a ``PyLayerContext`` object, followed by any
        number of arguments (tensors or other types). ``None`` must not be
        included in the returned result.

        Args:
            *args(tuple): positional inputs of the PyLayer.
            **kwargs(dict): keyword inputs of the PyLayer.

        Returns:
            tensors or other types: output of the PyLayer.

        Raises:
            NotImplementedError: always, unless a subclass overrides this.

        Examples:
            .. code-block:: python

                import paddle
                from paddle.autograd import PyLayer

                class cus_tanh(PyLayer):
                    @staticmethod
                    def forward(ctx, x):
                        y = paddle.tanh(x)
                        # Pass tensors to backward.
                        ctx.save_for_backward(y)
                        return y

                    @staticmethod
                    def backward(ctx, dy):
                        # Get the tensors passed by forward.
                        y, = ctx.saved_tensor()
                        grad = dy * (1 - paddle.square(y))
                        return grad
        """
        missing_impl_msg = "You must implement the forward function for PyLayer."
        raise NotImplementedError(missing_impl_msg)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        """
        Compute the gradients of the custom layer. Subclasses must override
        this.

        The first argument is a ``PyLayerContext`` object. The remaining
        arguments are the gradients of ``forward``'s output tensors. The
        returned tensors are the gradients of ``forward``'s input tensors.

        Args:
            *args(tuple): the gradient(s) of ``forward``'s output tensor(s).
            **kwargs(dict): the gradient(s) of ``forward``'s output tensor(s).

        Returns:
            Tensor or list of Tensors: the gradient(s) of ``forward``'s input
            tensor(s).

        Raises:
            NotImplementedError: always, unless a subclass overrides this.

        Examples:
            .. code-block:: python

                import paddle
                from paddle.autograd import PyLayer

                class cus_tanh(PyLayer):
                    @staticmethod
                    def forward(ctx, x):
                        y = paddle.tanh(x)
                        # Pass tensors to backward.
                        ctx.save_for_backward(y)
                        return y

                    @staticmethod
                    def backward(ctx, dy):
                        # Get the tensors passed by forward.
                        y, = ctx.saved_tensor()
                        grad = dy * (1 - paddle.square(y))
                        return grad
        """
        missing_impl_msg = "You must implement the backward function for PyLayer."
        raise NotImplementedError(missing_impl_msg)