nn.py 1.9 KB
Newer Older
L
LielinJiang 已提交
1
import paddle
L
fix nan  
LielinJiang 已提交
2
import paddle.nn as nn
L
LielinJiang 已提交
3 4


L
fix nan  
LielinJiang 已提交
5
class _SpectralNorm(nn.SpectralNorm):
    """``nn.SpectralNorm`` variant that skips power iteration outside training.

    The constructor is a pass-through to ``nn.SpectralNorm``. Only ``forward``
    is overridden: it emits the ``spectral_norm`` op itself so that the
    ``power_iters`` attribute can depend on ``self.training``.
    """

    def __init__(self,
                 weight_shape,
                 dim=0,
                 power_iters=1,
                 eps=1e-12,
                 dtype='float32'):
        """Forward every argument unchanged to ``nn.SpectralNorm``.

        Args:
            weight_shape: shape of the weight that will be normalized.
            dim: forwarded as ``dim`` (later read back as ``self._dim``).
            power_iters: iteration count used while ``self.training`` is True.
            eps: forwarded as ``eps`` (later read back as ``self._eps``).
            dtype: dtype used for the op's output variable.
        """
        super().__init__(weight_shape, dim, power_iters, eps, dtype)

    def forward(self, weight):
        """Run the ``spectral_norm`` op on ``weight`` and return its output.

        Args:
            weight: tensor fed to the op's ``Weight`` input.

        Returns:
            The ``Out`` variable created for the op.
        """
        # Zero iterations are requested whenever the layer is not training.
        # NOTE(review): presumably this keeps evaluation from updating U/V;
        # confirm against the spectral_norm op documentation.
        iterations = self._power_iters if self.training else 0

        op_inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v}
        op_attrs = {
            "dim": self._dim,
            "power_iters": iterations,
            "eps": self._eps,
        }

        result = self._helper.create_variable_for_type_inference(self._dtype)
        self._helper.append_op(type="spectral_norm",
                               inputs=op_inputs,
                               outputs={"Out": result},
                               attrs=op_attrs)
        return result


class Spectralnorm(paddle.nn.Layer):
    """Wrap ``layer`` so that its ``weight`` is spectrally normalized each call.

    On construction, the wrapped layer's ``'weight'`` entry is removed from
    ``layer._parameters`` and re-created on this wrapper as ``weight_orig``,
    with the same shape, dtype and value. Each ``forward`` call passes
    ``weight_orig`` through a ``_SpectralNorm``, assigns the result to
    ``layer.weight``, and then invokes ``layer``.

    Args:
        layer: the layer to wrap. It must have a ``weight`` attribute and a
            ``'weight'`` key in ``layer._parameters``.
        dim: forwarded to ``_SpectralNorm``.
        power_iters: forwarded to ``_SpectralNorm``.
        eps: forwarded to ``_SpectralNorm``.
        dtype: forwarded to ``_SpectralNorm``. ``weight_orig`` itself uses
            the original weight's dtype.
    """

    def __init__(self, layer, dim=0, power_iters=1, eps=1e-12, dtype='float32'):
        super().__init__()
        # The normalizer is created before weight_orig, matching the
        # original registration order of sublayers and parameters.
        self.spectral_norm = _SpectralNorm(layer.weight.shape, dim,
                                           power_iters, eps, dtype)
        self.dim = dim
        self.power_iters = power_iters
        self.eps = eps
        self.layer = layer

        # Detach the original parameter from the wrapped layer.
        # NOTE(review): presumably this is needed so that forward may assign
        # a plain (non-Parameter) tensor to layer.weight; confirm against
        # paddle.nn.Layer.__setattr__.
        source_weight = layer._parameters.pop('weight')

        self.weight_orig = self.create_parameter(source_weight.shape,
                                                 dtype=source_weight.dtype)
        self.weight_orig.set_value(source_weight)

    def forward(self, x):
        """Normalize ``weight_orig``, install it on the wrapped layer, run it.

        Args:
            x: input passed straight to the wrapped layer.

        Returns:
            Whatever the wrapped layer returns for ``x``.
        """
        self.layer.weight = self.spectral_norm(self.weight_orig)
        return self.layer(x)