#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import math

import numpy as np

from ...fluid.framework import default_main_program, in_dygraph_mode
from .lazy_init import lazy_init_helper

__all__ = []


class Initializer:
    """Base class for parameter initializers

    Defines the common interface of parameter initializers.
    They add operations to the init program that are used
    to initialize parameters. Users should not use this class
    directly; use one of its implementations instead.
    """

    def __init__(self):
        pass

    def __call__(self, param, block=None):
        # Initialize the parameter immediately unless lazy initialization is
        # enabled, in which case only record how to initialize it later.
        if not lazy_init_helper().state:
            return self.forward(param, block)

        return self._lazy_init(param, block)

    def forward(self, param, block=None):
        """Add corresponding initialization operations to the network."""
        raise NotImplementedError()

    def _lazy_init(self, param, block=None):
        """
        Apply lazy initialization
        """
        assert in_dygraph_mode()

        def init_op_creator(forward, param, block):
            new_var = param._to_static_var(True, block=block)
            # Record initializer operator
            with lazy_init_helper():
                forward(new_var, block)

        # Add hook function for initializing param in dygraph mode
        param.set_init_func(functools.partial(self.forward))
        param._init_op_creator = functools.partial(
            init_op_creator, self.forward
        )

        return param
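
    # Illustrative sketch (added for clarity, not part of the original class)
    # of the lazy-initialization flow that _lazy_init supports, assuming the
    # public ``paddle.LazyGuard`` context manager:
    #
    #     import paddle
    #
    #     with paddle.LazyGuard():
    #         # Inside the guard, lazy_init_helper().state is enabled, so
    #         # Initializer.__call__ routes to _lazy_init(): the parameter
    #         # only records self.forward as its init function instead of
    #         # running it right away.
    #         linear = paddle.nn.Linear(10, 10)
    #
    #     # The recorded init function is executed later by the framework
    #     # when the parameter is actually materialized.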

    def _check_block(self, block):
        if block is None:
            block = default_main_program().global_block()

        return block

    def _compute_fans(self, var):
        """Compute the fan_in and the fan_out for layers

        This method computes the fan_in and the fan_out
        for neural network layers when they are not specified.
        It is not possible to estimate fan_in and fan_out
        perfectly in general; this method estimates them
        correctly for matrix multiplies and convolutions.

        Args:
            var: variable for which fan_in and fan_out have to be computed.

        Returns:
            tuple of two integers (fan_in, fan_out).
        """
        shape = var.shape
        if not shape or len(shape) == 0:
            fan_in = fan_out = 1
        elif len(shape) == 1:
            fan_in = fan_out = shape[0]
        elif len(shape) == 2:
            # This is the case for simple matrix multiply
            fan_in = shape[0]
            fan_out = shape[1]
        else:
            # Assume this to be a convolutional kernel
            # In PaddlePaddle, the shape of the kernel is like:
            # [num_filters, num_filter_channels, ...] where the remaining
            # dimensions are the filter_size
            receptive_field_size = np.prod(shape[2:])
            fan_in = shape[1] * receptive_field_size
            fan_out = shape[0] * receptive_field_size

        return (fan_in, fan_out)
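
    # Illustrative sketch (added for clarity, not part of the original class):
    # how _compute_fans resolves fan_in / fan_out for a hypothetical Conv2D
    # kernel of shape [num_filters, num_filter_channels, *filter_size]:
    #
    #     shape = (64, 3, 3, 3)
    #     receptive_field_size = 3 * 3          # np.prod(shape[2:]) = 9
    #     fan_in  = 3 * receptive_field_size    # shape[1] * 9 = 27
    #     fan_out = 64 * receptive_field_size   # shape[0] * 9 = 576
    #
    # For a 2-D weight, fan_in and fan_out are simply shape[0] and shape[1].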


def calculate_gain(nonlinearity, param=None):
    """
    Get the recommended ``gain`` value for a nonlinearity function. The ``gain`` value can be used in some
    ``paddle.nn.initializer`` APIs to adjust the initialization value.

    Args:
        nonlinearity(str): name of the nonlinearity activation function. For linear-like functions, such as
            `linear/conv1d/conv2d/conv3d/conv1d_transpose/conv2d_transpose/conv3d_transpose`, 1.0 will be returned.
        param(bool|int|float, optional): optional parameter for some nonlinearity functions. Currently it only applies to
            'leaky_relu'. Default: None, in which case 0.01 is used in the formula.

    Returns:
        A float value, which is the recommended gain for this nonlinearity function.

    Examples:
        .. code-block:: python

            >>> import paddle

            >>> gain = paddle.nn.initializer.calculate_gain('tanh')
            >>> print(gain)
            1.6666666666666667
            >>> # 5.0 / 3
            >>> gain = paddle.nn.initializer.calculate_gain('leaky_relu', param=1.0)
            >>> print(gain)
            1.0
            >>> # math.sqrt(2.0 / (1+param^2))
            >>> initializer = paddle.nn.initializer.Orthogonal(gain)

    """
    if param is None:
        param = 0.01
    else:
        assert isinstance(param, (bool, int, float))
        param = float(param)
    recommended_gain = {
        'sigmoid': 1,
        'linear': 1,
        'conv1d': 1,
        'conv2d': 1,
        'conv3d': 1,
        'conv1d_transpose': 1,
        'conv2d_transpose': 1,
        'conv3d_transpose': 1,
        'tanh': 5.0 / 3,
        'relu': math.sqrt(2.0),
        'leaky_relu': math.sqrt(2.0 / (1 + param**2)),
        'selu': 3.0 / 4,
    }
    if nonlinearity in recommended_gain.keys():
        return recommended_gain[nonlinearity]
    else:
        raise ValueError(
            "nonlinearity function {} is not supported now.".format(
                nonlinearity
            )
        )
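
# Worked example (illustration only, not part of the original module): with the
# default ``param`` of 0.01, the 'leaky_relu' entry above evaluates to
#     math.sqrt(2.0 / (1 + 0.01 ** 2)) ~= 1.4141
# which is close to the 'relu' gain of sqrt(2), while
# calculate_gain('leaky_relu', param=1.0) collapses to sqrt(2.0 / 2) == 1.0,
# matching the docstring example.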