diff --git a/paddle/fluid/operators/spectral_norm_op.cc b/paddle/fluid/operators/spectral_norm_op.cc index 8e083fa05c75dddc6dcc6e33a7c73b614a8da46d..8d486341b3447e85e1f34d3050cd66fe35ba2c09 100644 --- a/paddle/fluid/operators/spectral_norm_op.cc +++ b/paddle/fluid/operators/spectral_norm_op.cc @@ -205,6 +205,8 @@ class SpectralNormGradOpMaker : public framework::SingleGradOpMaker { op->SetInput("Weight", this->Input("Weight")); op->SetInput("U", this->Input("U")); op->SetInput("V", this->Input("V")); + op->SetInput("UOut", this->Output("UOut")); + op->SetInput("VOut", this->Output("VOut")); op->SetOutput(framework::GradVarName("Weight"), this->InputGrad("Weight")); diff --git a/paddle/fluid/operators/spectral_norm_op.h b/paddle/fluid/operators/spectral_norm_op.h index 15d637a495b5650dd98580019fa8805147c7d6de..fa334942b13348db0679159a4c23f08c34742cd6 100644 --- a/paddle/fluid/operators/spectral_norm_op.h +++ b/paddle/fluid/operators/spectral_norm_op.h @@ -64,6 +64,7 @@ template static inline void UpdateUandV( Tensor* u, Tensor* v, Tensor* weight, const int power_iters, const float eps, const framework::ExecutionContext& ctx) { + if (power_iters <= 0) return; auto& place = *ctx.template device_context().eigen_device(); auto blas = math::GetBlas(ctx); auto u_t = EigenTensor::From(*u); @@ -92,8 +93,7 @@ static inline void UpdateUandV( template static inline void CalcMatrixSigmaAndNormWeight( Tensor* sigma, const Tensor* u, const Tensor* v, - Tensor* weight, const int power_iters, - const float eps, const framework::ExecutionContext& ctx) { + Tensor* weight, const framework::ExecutionContext& ctx) { auto& place = *ctx.template device_context().eigen_device(); auto blas = math::GetBlas(ctx); auto sigma_t = EigenTensor::From(*sigma); @@ -168,7 +168,7 @@ class SpectralNormKernel : public framework::OpKernel { power_iters, eps, ctx); CalcMatrixSigmaAndNormWeight( &sigma, &(u_out->Resize({h, 1})), &(v_out->Resize({w, 1})), &weight_mat, - power_iters, eps, ctx); + ctx); if 
(dim != 0) { std::vector perm; @@ -205,8 +205,6 @@ class SpectralNormGradKernel : public framework::OpKernel { auto weight_grad = ctx.Output(framework::GradVarName("Weight")); int dim = ctx.Attr("dim"); - int power_iters = ctx.Attr("power_iters"); - float eps = ctx.Attr("eps"); const int h = u_out->dims()[0]; const int w = v_out->dims()[0]; @@ -251,7 +249,7 @@ class SpectralNormGradKernel : public framework::OpKernel { CalcMatrixSigmaAndNormWeight( &sigma, &(u_mat.Resize({h, 1})), &(v_mat.Resize({w, 1})), &weight_mat, - power_iters, eps, ctx); + ctx); Tensor uv; uv.mutable_data({h, w}, ctx.GetPlace()); diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index a14c3a81c121758ed90450cd5eb5990f3f7739e1..fc76dbc216ff82c8555f910b5fd1b1a6f3f4eb17 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -3031,6 +3031,7 @@ class SpectralNorm(layers.Layer): dim(int, optional): The index of dimension which should be permuted to the first before reshaping Input(Weight) to matrix, it should be set as 0 if Input(Weight) is the weight of fc layer, and should be set as 1 if Input(Weight) is the weight of conv layer. Default: 0. power_iters(int, optional): The number of power iterations to calculate spectral norm. Default: 1. eps(float, optional): The epsilon for numerical stability in calculating norms. Default: 1e-12. + fix_state(bool, optional): If True, keep the input vectors `u` and `v` fixed and write their power-iteration updates to separate buffers; if False, update `u` and `v` in place. Default: True. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` . dtype (str, optional): Data type, it can be "float32" or "float64". Default: "float32". 
@@ -3055,10 +3056,12 @@ class SpectralNorm(layers.Layer): dim=0, power_iters=1, eps=1e-12, + fix_state=True, dtype='float32'): super(SpectralNorm, self).__init__() self._power_iters = power_iters self._eps = eps + self._fix_state = fix_state self._dim = dim self._dtype = dtype @@ -3080,10 +3083,31 @@ class SpectralNorm(layers.Layer): default_initializer=Normal(0., 1.)) self.weight_v.stop_gradient = True + if fix_state: + self.out_weight_u = self.create_parameter( + attr=ParamAttr(), + shape=[h], + dtype=self._dtype, + default_initializer=Normal(0., 1.)) + self.out_weight_u.stop_gradient = True + + self.out_weight_v = self.create_parameter( + attr=ParamAttr(), + shape=[w], + dtype=self._dtype, + default_initializer=Normal(0., 1.)) + self.out_weight_v.stop_gradient = True + else: + self.out_weight_u = self.weight_u + self.out_weight_v = self.weight_v + def forward(self, weight): check_variable_and_dtype(weight, "weight", ['float32', 'float64'], 'SpectralNorm') - inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v} + inputs = { + 'Weight': weight, 'U': self.weight_u, 'V': self.weight_v, + 'UOut': self.out_weight_u, 'VOut': self.out_weight_v, + } out = self._helper.create_variable_for_type_inference(self._dtype) self._helper.append_op( type="spectral_norm", diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 9313de8c64fcf4efc1e192ad2826f05f51869bbf..61a6b0ad9948b104333ef5a15e08e850bcaa414e 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -3720,11 +3720,13 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None): # create output out = helper.create_variable(dtype=dtype) + u_out = helper.create_variable(dtype=dtype) + v_out = helper.create_variable(dtype=dtype) helper.append_op( type="spectral_norm", inputs=inputs, - outputs={"Out": out, }, + outputs={"Out": out, "UOut": u_out, "VOut": v_out}, attrs={ "dim": dim, "power_iters": power_iters, diff --git 
a/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py b/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py index 227ed1d2ff02c836572df7da2ede9f4131d01f0d..c080583cec0aa4da0e0c5abe093f4bb284bf0dd3 100644 --- a/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py @@ -132,6 +132,40 @@ class TestSpectralNormOp2(TestSpectralNormOp): self.fix_state = True +class TestSpectralNormOpFixState(TestSpectralNormOpNoGrad): + def test_check_grad_ignore_uv(self): + self.check_grad( + ['Weight'], + 'Out', + no_grad_set=set(["U", "V"]), ) + + def initTestCase(self): + self.weight_shape = (10, 12) + self.u_shape = (10, ) + self.v_shape = (12, ) + self.dim = 0 + self.power_iters = 3 + self.eps = 1e-12 + self.fix_state = False + + +class TestSpectralNormOpUpdateState(TestSpectralNormOpNoGrad): + def test_check_grad_ignore_uv(self): + self.check_grad( + ['Weight'], + 'Out', + no_grad_set=set(["U", "V"]), ) + + def initTestCase(self): + self.weight_shape = (10, 12) + self.u_shape = (10, ) + self.v_shape = (12, ) + self.dim = 0 + self.power_iters = 3 + self.eps = 1e-12 + self.fix_state = True + + class TestSpectralNormOpError(unittest.TestCase): def test_errors(self): with program_guard(Program(), Program()):