# spectral_norm_hook.py (last committed by wangna11BD)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import numpy as np

import paddle
from ..layer.conv import Conv1DTranspose, Conv2DTranspose, Conv3DTranspose
from ..layer.common import Linear
from .. import functional as F

__all__ = []


def normal_(x, mean=0., std=1.):
    """Overwrite ``x`` in place with normal samples and return ``x``.

    The samples are drawn by ``paddle.normal`` with the given ``mean`` and
    ``std`` and with the same shape as ``x``. They are written into ``x``
    via ``set_value``.
    """
    x.set_value(paddle.normal(mean, std, shape=x.shape))
    return x


class SpectralNorm(object):
    """Forward pre-hook that rescales a layer parameter by its spectral norm.

    An instance is registered on a layer via :meth:`apply`. Before each
    forward call, it sets ``layer.<name>`` to ``<name>_orig / sigma``.
    ``sigma = u^T W v`` is a power-iteration estimate of the spectral norm
    of ``<name>_orig`` after it is reshaped into a 2-D matrix ``W``.
    """

    def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
        """
        Args:
            name (str): Name of the parameter to normalize.
            n_power_iterations (int): Number of power-iteration steps run on
                each forward call while the layer is training. Must be
                positive.
            dim (int): Dimension moved to the front before the weight is
                flattened into a matrix. Negative values count from the end.
            eps (float): Epsilon passed to ``F.normalize`` for numerical
                stability.

        Raises:
            ValueError: If ``n_power_iterations`` is not positive.
        """
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError(
                'Expected n_power_iterations to be positive, but '
                'got n_power_iterations={}'.format(n_power_iterations))
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def reshape_weight_to_matrix(self, weight):
        """Flatten ``weight`` into a 2-D matrix of shape ``[H, -1]``.

        ``self.dim`` is first transposed to the front. ``H`` is therefore the
        size of that dimension, and all remaining dimensions are flattened
        into the columns. A negative ``self.dim`` is resolved against the
        rank of ``weight``.
        """
        ndim = len(weight.shape)
        # Resolve negative dims; otherwise the permutation below would contain
        # ndim + 1 entries and the transpose would fail.
        dim = self.dim + ndim if self.dim < 0 else self.dim

        weight_mat = weight
        if dim != 0:
            # transpose dim to front
            weight_mat = weight_mat.transpose(
                [dim] + [d for d in range(ndim) if d != dim])

        height = weight_mat.shape[0]
        return weight_mat.reshape([height, -1])

    def compute_weight(self, layer, do_power_iteration):
        """Return ``<name>_orig`` divided by its estimated spectral norm.

        Args:
            layer (Layer): Layer holding the ``<name>_orig`` parameter and the
                ``<name>_u`` / ``<name>_v`` buffers registered by :meth:`apply`.
            do_power_iteration (bool): If True, update the ``u`` and ``v``
                buffers in place with ``n_power_iterations`` power-iteration
                steps (without gradient tracking) before estimating sigma.

        Returns:
            Tensor: ``weight / sigma`` where ``sigma = u^T W v``.
        """
        weight = getattr(layer, self.name + '_orig')
        u = getattr(layer, self.name + '_u')
        v = getattr(layer, self.name + '_v')
        weight_mat = self.reshape_weight_to_matrix(weight)

        if do_power_iteration:
            with paddle.no_grad():
                for _ in range(self.n_power_iterations):
                    # v <- normalize(W^T u)
                    v.set_value(
                        F.normalize(
                            paddle.matmul(weight_mat,
                                          u,
                                          transpose_x=True,
                                          transpose_y=False),
                            axis=0,
                            epsilon=self.eps,
                        ))
                    # u <- normalize(W v)
                    u.set_value(
                        F.normalize(
                            paddle.matmul(weight_mat, v),
                            axis=0,
                            epsilon=self.eps,
                        ))
                if self.n_power_iterations > 0:
                    # NOTE(review): presumably mirrors the PyTorch
                    # implementation. Cloning snapshots u/v so that later
                    # in-place set_value updates of the buffers do not alter
                    # the values used by this forward's graph. Confirm.
                    u = u.clone()
                    v = v.clone()

        sigma = paddle.dot(u, paddle.mv(weight_mat, v))
        weight = weight / sigma
        return weight

    def __call__(self, layer, inputs):
        """Forward pre-hook: set ``layer.<name>`` to the normalized weight.

        Power iteration runs only while ``layer.training`` is True. ``inputs``
        is not used, and the hook returns None.
        """
        setattr(layer, self.name,
                self.compute_weight(layer, do_power_iteration=layer.training))

    @staticmethod
    def apply(layer, name, n_power_iterations, dim, eps):
        """Attach a SpectralNorm hook to ``layer`` for parameter ``name``.

        Moves the parameter to ``<name>_orig`` and registers randomly
        initialized, normalized ``<name>_u`` / ``<name>_v`` buffers. Leaves
        ``weight * 1.0`` as a plain ``<name>`` attribute, then registers the
        hook as a forward pre-hook.

        Returns:
            SpectralNorm: The registered hook.

        Raises:
            RuntimeError: If a SpectralNorm hook for ``name`` is already
                registered on ``layer``.
            ValueError: If ``n_power_iterations`` is not positive.
            KeyError: If ``layer`` has no parameter called ``name``.
        """
        for hook in layer._forward_pre_hooks.values():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError("Cannot register two spectral_norm hooks on "
                                   "the same parameter {}".format(name))

        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = layer._parameters[name]

        with paddle.no_grad():
            weight_mat = fn.reshape_weight_to_matrix(weight)
            h, w = weight_mat.shape

            # randomly initialize u and v
            u = layer.create_parameter([h])
            u = normal_(u, 0., 1.)
            v = layer.create_parameter([w])
            v = normal_(v, 0., 1.)
            u = F.normalize(u, axis=0, epsilon=fn.eps)
            v = F.normalize(v, axis=0, epsilon=fn.eps)

        # delete fn.name from parameters, otherwise you can not set attribute
        del layer._parameters[fn.name]
        layer.add_parameter(fn.name + "_orig", weight)
        # still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign as it could be a Parameter and
        # gets added as a parameter. Instead, we register weight * 1.0 as a plain
        # attribute.
        setattr(layer, fn.name, weight * 1.0)
        layer.register_buffer(fn.name + "_u", u)
        layer.register_buffer(fn.name + "_v", v)
        layer.register_forward_pre_hook(fn)
        return fn


def spectral_norm(layer,
                  name='weight',
                  n_power_iterations=1,
                  eps=1e-12,
                  dim=None):
    r"""
    Applies spectral normalization to a parameter of ``layer`` according to the
    following calculation:

    Step 1:
    Generate vector U in shape of [H], and V in shape of [W].
    Here H is the size of the :attr:`dim`-th dimension of the input weight,
    and W is the product of the remaining dimensions.

    Step 2:
    :attr:`n_power_iterations` should be a positive integer. Do the following
    calculations with U and V for :attr:`n_power_iterations` rounds.

    .. math::

        \mathbf{v} := \frac{\mathbf{W}^{T} \mathbf{u}}{\|\mathbf{W}^{T} \mathbf{u}\|_2}

        \mathbf{u} := \frac{\mathbf{W} \mathbf{v}}{\|\mathbf{W} \mathbf{v}\|_2}

    Step 3:
    Calculate :math:`\sigma(\mathbf{W})` and normalize weight values.

    .. math::

        \sigma(\mathbf{W}) = \mathbf{u}^{T} \mathbf{W} \mathbf{v}

        \mathbf{W} = \frac{\mathbf{W}}{\sigma(\mathbf{W})}


    Refer to `Spectral Normalization <https://arxiv.org/abs/1802.05957>`_ .

    Parameters:
        layer(Layer): Layer of paddle, which has weight.
        name(str, optional): Name of the weight parameter. Default: 'weight'.
        n_power_iterations(int, optional): The number of power iterations to calculate spectral norm. Default: 1.
        eps(float, optional): The epsilon for numerical stability in calculating norms. Default: 1e-12.
        dim(int, optional): The index of the dimension that is permuted to the front before the weight is reshaped into a matrix. If None, it is set to 1 when ``layer`` is a ``Linear``, ``Conv1DTranspose``, ``Conv2DTranspose`` or ``Conv3DTranspose`` layer, and to 0 for any other layer (e.g. ordinary convolution layers). Default: None.

    Returns:
        The original layer with the spectral norm hook

    Raises:
        ValueError: If ``n_power_iterations`` is not positive.
        RuntimeError: If spectral normalization is already applied to
            parameter ``name`` of ``layer``.

    Examples:
       .. code-block:: python

            from paddle.nn import Conv2D
            from paddle.nn.utils import spectral_norm

            conv = Conv2D(3, 1, 3)
            sn_conv = spectral_norm(conv)
            print(sn_conv)
            # Conv2D(3, 1, kernel_size=[3, 3], data_format=NCHW)
            print(sn_conv.weight)
            # Tensor(shape=[1, 3, 3, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
            #        [[[[-0.21090528,  0.18563725, -0.14127982],
            #           [-0.02310637,  0.03197737,  0.34353802],
            #           [-0.17117859,  0.33152047, -0.28408015]],
            #
            #          [[-0.13336606, -0.01862637,  0.06959272],
            #           [-0.02236020, -0.27091628, -0.24532901],
            #           [ 0.27254242,  0.15516677,  0.09036587]],
            #
            #          [[ 0.30169338, -0.28146112, -0.11768346],
            #           [-0.45765871, -0.12504843, -0.17482486],
            #           [-0.36866254, -0.19969313,  0.08783543]]]])

    """

    if dim is None:
        # NOTE(review): presumably these layer types keep their output
        # features/channels in weight dimension 1 -- confirm against the
        # layer definitions.
        dim_one_layers = (Conv1DTranspose, Conv2DTranspose, Conv3DTranspose,
                          Linear)
        dim = 1 if isinstance(layer, dim_one_layers) else 0
    SpectralNorm.apply(layer, name, n_power_iterations, dim, eps)
    return layer