[论文复现] GAN中的SpectralNorm Op用法
Created by: XDang13
class SpertralNormLinear(fluid.dygraph.Layer):
    """Fully connected layer whose weight is spectrally normalized on each call.

    The trainable parameter is the raw weight ``self.linear.weight``. On every
    forward pass it is passed through ``self.sp_block`` (a
    ``fluid.dygraph.nn.SpectralNorm`` layer), and the *result* is used in the
    matrix multiply. The raw parameter itself is never overwritten, so
    gradients flow through the normalization back to it.

    Args:
        input_dim: size of the last dimension of the input.
        dim: number of output features.
        bias_attr: bias attribute forwarded to ``Linear``; ``False`` (the
            default) creates no bias.
        power_iters: number of power iterations ``SpectralNorm`` uses to
            estimate the largest singular value (default 2, the value this
            layer previously hard-coded).
    """

    def __init__(self, input_dim, dim, bias_attr=False, power_iters=2):
        super(SpertralNormLinear, self).__init__()
        # Owns the raw weight (and optional bias) parameters; its own forward
        # is not used, because that would multiply by the un-normalized weight.
        self.linear = Linear(input_dim, dim, bias_attr=bias_attr)
        # dim=0 is the setting Paddle documents for fc/Linear weights.
        self.sp_block = fluid.dygraph.nn.SpectralNorm(
            self.linear.weight.shape, dim=0, power_iters=power_iters)

    def forward(self, input):
        """Return ``input @ W_sn`` plus the bias if one exists.

        ``W_sn`` is the spectrally normalized weight returned by
        :meth:`spectral_norm`.
        """
        # Use the normalized weight functionally rather than writing it back
        # into the parameter. A set_value() write-back is not part of the
        # autograd graph, so the normalization would be invisible to the
        # backward pass and the raw weight would be lost each step.
        weight = self.spectral_norm()
        output = fluid.layers.matmul(input, weight)
        if self.linear.bias is not None:
            # Add the bias along the last axis, matching fluid's Linear.
            output = fluid.layers.elementwise_add(
                output, self.linear.bias, axis=len(input.shape) - 1)
        return output

    def spectral_norm(self):
        """Return the spectrally normalized copy of ``self.linear.weight``.

        The raw weight parameter is left unmodified.
        """
        return self.sp_block(self.linear.weight)
还没有在 GAN 中测试，理论上应该可行。