# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from ..layer.conv import Conv1DTranspose, Conv2DTranspose, Conv3DTranspose
from ..layer.common import Linear
from .. import functional as F

__all__ = []


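# In-place initialization helper: overwrite ``x`` with samples drawn from a
# normal distribution with the given mean and standard deviation.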
def normal_(x, mean=0.0, std=1.0):
    temp_value = paddle.normal(mean, std, shape=x.shape)
    x.set_value(temp_value)
    return x


class SpectralNorm:
    def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError(
                'Expected n_power_iterations to be positive, but '
                'got n_power_iterations={}'.format(n_power_iterations)
            )
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def reshape_weight_to_matrix(self, weight):
        weight_mat = weight
        if self.dim != 0:
            # transpose dim to front
            weight_mat = weight_mat.transpose(
                [self.dim]
                + [d for d in range(weight_mat.dim()) if d != self.dim]
            )

        height = weight_mat.shape[0]

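        # collapse all remaining dimensions into the second axis so the
        # spectral norm can be computed on a 2-D matrix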
        return weight_mat.reshape([height, -1])

    def compute_weight(self, layer, do_power_iteration):
        weight = getattr(layer, self.name + '_orig')
        u = getattr(layer, self.name + '_u')
        v = getattr(layer, self.name + '_v')
        weight_mat = self.reshape_weight_to_matrix(weight)

        if do_power_iteration:
            with paddle.no_grad():
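                # Power iteration: alternately refine the estimates of the
                # right (v) and left (u) singular vectors of weight_mat.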
                for _ in range(self.n_power_iterations):
                    v.set_value(
                        F.normalize(
                            paddle.matmul(
                                weight_mat,
                                u,
                                transpose_x=True,
                                transpose_y=False,
                            ),
                            axis=0,
                            epsilon=self.eps,
                        )
                    )

                    u.set_value(
                        F.normalize(
                            paddle.matmul(weight_mat, v),
                            axis=0,
                            epsilon=self.eps,
                        )
                    )
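                # Clone so that sigma below is not computed directly from the
                # buffers that were just updated in place under no_grad.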
                if self.n_power_iterations > 0:
                    u = u.clone()
                    v = v.clone()

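        # sigma = u^T W v approximates the largest singular value of weight_mat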
        sigma = paddle.dot(u, paddle.mv(weight_mat, v))
        weight = weight / sigma
        return weight

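    # Forward-pre-hook: refresh the normalized weight before every forward
    # pass; power iteration runs only while the layer is in training mode.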
    def __call__(self, layer, inputs):
        setattr(
            layer,
            self.name,
            self.compute_weight(layer, do_power_iteration=layer.training),
        )

    @staticmethod
    def apply(layer, name, n_power_iterations, dim, eps):
        for k, hook in layer._forward_pre_hooks.items():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError(
                    "Cannot register two spectral_norm hooks on "
                    "the same parameter {}".format(name)
                )

        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = layer._parameters[name]

        with paddle.no_grad():
            weight_mat = fn.reshape_weight_to_matrix(weight)
            h, w = weight_mat.shape

            # randomly initialize u and v
            u = layer.create_parameter([h])
            u = normal_(u, 0.0, 1.0)
            v = layer.create_parameter([w])
            v = normal_(v, 0.0, 1.0)
            u = F.normalize(u, axis=0, epsilon=fn.eps)
            v = F.normalize(v, axis=0, epsilon=fn.eps)

        # delete fn.name from parameters, otherwise we cannot set the attribute
        del layer._parameters[fn.name]
        layer.add_parameter(fn.name + "_orig", weight)
        # still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign it, as it could be a Parameter and
        # would get added as a parameter. Instead, we register weight * 1.0 as
        # a plain attribute.
        setattr(layer, fn.name, weight * 1.0)
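        # register u and v as buffers so the power-iteration state persists
        # across forward passes and is saved with the layer's state dict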
        layer.register_buffer(fn.name + "_u", u)
        layer.register_buffer(fn.name + "_v", v)
        layer.register_forward_pre_hook(fn)
        return fn


def spectral_norm(
    layer, name='weight', n_power_iterations=1, eps=1e-12, dim=None
):
    r"""
    Applies spectral normalization to a parameter according to the
    following calculation:

    Step 1:
    Generate a vector U of shape [H] and a vector V of shape [W],
    where H is the size of the :attr:`dim`-th dimension of the input
    weight and W is the product of the sizes of the remaining dimensions.

    Step 2:
    :attr:`n_power_iterations` should be a positive integer; perform the
    following calculations with U and V for :attr:`n_power_iterations` rounds.

    .. math::

        \mathbf{v} := \frac{\mathbf{W}^{T} \mathbf{u}}{\|\mathbf{W}^{T} \mathbf{u}\|_2}

        \mathbf{u} := \frac{\mathbf{W} \mathbf{v}}{\|\mathbf{W} \mathbf{v}\|_2}

    Step 3:
    Calculate :math:`\sigma(\mathbf{W})` and normalize weight values.

    .. math::

        \sigma(\mathbf{W}) = \mathbf{u}^{T} \mathbf{W} \mathbf{v}

        \mathbf{W} = \frac{\mathbf{W}}{\sigma(\mathbf{W})}


    Refer to `Spectral Normalization <https://arxiv.org/abs/1802.05957>`_ .

    Parameters:
        layer(Layer): The Paddle layer that holds the weight parameter to be normalized.
        name(str, optional): Name of the weight parameter. Default: 'weight'.
        n_power_iterations(int, optional): The number of power iterations to calculate spectral norm. Default: 1.
        eps(float, optional): The epsilon for numerical stability in calculating norms. Default: 1e-12.
        dim(int, optional): The index of the dimension that is permuted to the front before the weight is reshaped to a matrix. Set it to 0 when the weight stores the output dimension first (e.g. Conv layers), and to 1 when the output dimension comes second (e.g. Linear and conv transpose layers). If None, the value is inferred from the layer type. Default: None.

    Returns:
        Layer, the original layer with the spectral norm hook.

    Examples:
       .. code-block:: python

            from paddle.nn import Conv2D
            from paddle.nn.utils import spectral_norm

            conv = Conv2D(3, 1, 3)
            sn_conv = spectral_norm(conv)
            print(sn_conv)
            # Conv2D(3, 1, kernel_size=[3, 3], data_format=NCHW)
            print(sn_conv.weight)
            # Tensor(shape=[1, 3, 3, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
            #        [[[[-0.21090528,  0.18563725, -0.14127982],
            #           [-0.02310637,  0.03197737,  0.34353802],
            #           [-0.17117859,  0.33152047, -0.28408015]],
            #
            #          [[-0.13336606, -0.01862637,  0.06959272],
            #           [-0.02236020, -0.27091628, -0.24532901],
            #           [ 0.27254242,  0.15516677,  0.09036587]],
            #
            #          [[ 0.30169338, -0.28146112, -0.11768346],
            #           [-0.45765871, -0.12504843, -0.17482486],
            #           [-0.36866254, -0.19969313,  0.08783543]]]])

    """

    if dim is None:
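        # these layer types store the output dimension of the weight at
        # index 1, so the spectral norm is computed over that dimension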
        if isinstance(
            layer, (Conv1DTranspose, Conv2DTranspose, Conv3DTranspose, Linear)
        ):
            dim = 1
        else:
            dim = 0
    SpectralNorm.apply(layer, name, n_power_iterations, dim, eps)
    return layer