diff --git a/paddle/fluid/operators/spectral_norm_op.cc b/paddle/fluid/operators/spectral_norm_op.cc index 71e5c978d79dc5736c8813b9b2735f9d7bc4c522..8e083fa05c75dddc6dcc6e33a7c73b614a8da46d 100644 --- a/paddle/fluid/operators/spectral_norm_op.cc +++ b/paddle/fluid/operators/spectral_norm_op.cc @@ -30,6 +30,8 @@ class SpectralNormOp : public framework::OperatorWithKernel { OP_INOUT_CHECK(ctx->HasInput("U"), "Input", "U", "SpectralNorm"); OP_INOUT_CHECK(ctx->HasInput("V"), "Input", "V", "SpectralNorm"); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SpectralNorm"); + OP_INOUT_CHECK(ctx->HasOutput("UOut"), "Output", "UOut", "SpectralNorm"); + OP_INOUT_CHECK(ctx->HasOutput("VOut"), "Output", "VOut", "SpectralNorm"); auto dim_weight = ctx->GetInputDim("Weight"); auto rank_weight = dim_weight.size(); @@ -88,6 +90,8 @@ class SpectralNormOp : public framework::OperatorWithKernel { } ctx->SetOutputDim("Out", dim_weight); + ctx->SetOutputDim("UOut", dim_u); + ctx->SetOutputDim("VOut", dim_v); ctx->ShareLoD("Weight", /*->*/ "Out"); } @@ -126,6 +130,10 @@ class SpectralNormOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "The output weight tensor of spectral_norm operator, " "This tensor is in same shape with Input(Weight)."); + AddOutput("UOut", + "The updated value of `U`"); + AddOutput("VOut", + "The updated value of `V`"); AddAttr("dim", "The index of dimension which should be permuted " @@ -145,7 +153,6 @@ class SpectralNormOpMaker : public framework::OpProtoAndCheckerMaker { "the denominator to aviod divide zero. " "Default 1e-12.") .SetDefault(1e-12); - AddComment(R"DOC( This layer calculates the spectral normalization value of weight of fc, conv1d, conv2d, conv3d layers which should be 2-D, 3-D, 4-D, 5-D diff --git a/paddle/fluid/operators/spectral_norm_op.h b/paddle/fluid/operators/spectral_norm_op.h index 954edc796914c84fc81898c581cb75ff6b16468d..15d637a495b5650dd98580019fa8805147c7d6de 100644 --- a/paddle/fluid/operators/spectral_norm_op.h +++ b/paddle/fluid/operators/spectral_norm_op.h @@ -61,13 +61,11 @@ static inline void TransCompute(const int rank, const Tensor& in, Tensor* out, } template -static inline void CalcMatrixSigmaAndNormWeight( - Tensor* sigma, Tensor* u, Tensor* v, Tensor* weight, const int power_iters, +static inline void UpdateUandV( + Tensor* u, Tensor* v, Tensor* weight, const int power_iters, const float eps, const framework::ExecutionContext& ctx) { auto& place = *ctx.template device_context().eigen_device(); auto blas = math::GetBlas(ctx); - auto sigma_t = EigenTensor::From(*sigma); - auto weight_t = EigenTensor::From(*weight); auto u_t = EigenTensor::From(*u); auto v_t = EigenTensor::From(*v); @@ -88,6 +86,23 @@ static inline void CalcMatrixSigmaAndNormWeight( Array1(h)); u_t.device(place) = u_t / (u_t_norm + u_t_norm.constant(eps)); } +} + +// CalcMatrixSigmaAndNormWeight will not update u and v +template +static inline void CalcMatrixSigmaAndNormWeight( + Tensor* sigma, const Tensor* u, const Tensor* v, + Tensor* weight, const int power_iters, + const float eps, const framework::ExecutionContext& ctx) { + auto& place = *ctx.template device_context().eigen_device(); + auto blas = math::GetBlas(ctx); + auto sigma_t = EigenTensor::From(*sigma); + auto weight_t = EigenTensor::From(*weight); + auto u_t = EigenTensor::From(*u); + + const int h = weight->dims()[0]; + const int w = weight->dims()[1]; + Tensor weight_v; weight_v.mutable_data({h, 1}, ctx.GetPlace()); blas.MatMul(*weight, false, *v, false, T(1), &weight_v, T(0)); @@ -109,6 +124,8 @@ class SpectralNormKernel : public framework::OpKernel { auto u = ctx.Input("U"); auto v = ctx.Input("V"); auto out = ctx.Output("Out"); + auto u_out = ctx.Output("UOut"); + auto v_out = ctx.Output("VOut"); int dim = ctx.Attr("dim"); int power_iters = ctx.Attr("power_iters"); @@ -144,11 +161,13 @@ class SpectralNormKernel : public framework::OpKernel { Tensor sigma; sigma.mutable_data(weight_mat.dims(), ctx.GetPlace()); - Tensor uu, vv; - TensorCopySync(*u, ctx.GetPlace(), &uu); - TensorCopySync(*v, ctx.GetPlace(), &vv); + TensorCopySync(*u, ctx.GetPlace(), u_out); + TensorCopySync(*v, ctx.GetPlace(), v_out); + UpdateUandV( + &(u_out->Resize({h, 1})), &(v_out->Resize({w, 1})), &weight_mat, + power_iters, eps, ctx); CalcMatrixSigmaAndNormWeight( - &sigma, &(uu.Resize({h, 1})), &(vv.Resize({w, 1})), &weight_mat, + &sigma, &(u_out->Resize({h, 1})), &(v_out->Resize({w, 1})), &weight_mat, power_iters, eps, ctx); if (dim != 0) { @@ -180,8 +199,8 @@ class SpectralNormGradKernel : public framework::OpKernel { auto& dev_ctx = ctx.template device_context(); auto blas = math::GetBlas(ctx); auto weight = ctx.Input("Weight"); - auto u = ctx.Input("U"); - auto v = ctx.Input("V"); + auto u_out = ctx.Input("UOut"); + auto v_out = ctx.Input("VOut"); auto out_grad = ctx.Input(framework::GradVarName("Out")); auto weight_grad = ctx.Output(framework::GradVarName("Weight")); @@ -189,8 +208,12 @@ class SpectralNormGradKernel : public framework::OpKernel { int power_iters = ctx.Attr("power_iters"); float eps = ctx.Attr("eps"); - const int h = u->dims()[0]; - const int w = v->dims()[0]; + const int h = u_out->dims()[0]; + const int w = v_out->dims()[0]; + + Tensor u_mat, v_mat; + TensorCopySync(*u_out, ctx.GetPlace(), &u_mat); + TensorCopySync(*v_out, ctx.GetPlace(), &v_mat); Tensor weight_mat, out_grad_mat; auto dims = weight->dims(); @@ -225,16 +248,14 @@ class SpectralNormGradKernel : public framework::OpKernel { Tensor sigma; sigma.mutable_data(weight_mat.dims(), ctx.GetPlace()); - Tensor uu, vv; - TensorCopySync(*u, ctx.GetPlace(), &uu); - TensorCopySync(*v, ctx.GetPlace(), &vv); + CalcMatrixSigmaAndNormWeight( - &sigma, &(uu.Resize({h, 1})), &(vv.Resize({w, 1})), &weight_mat, + &sigma, &(u_mat.Resize({h, 1})), &(v_mat.Resize({w, 1})), &weight_mat, power_iters, eps, ctx); Tensor uv; uv.mutable_data({h, w}, ctx.GetPlace()); - blas.MatMul(uu.Resize({h, 1}), false, vv.Resize({w, 1}), false, T(1), &uv, + blas.MatMul(u_mat.Resize({h, 1}), false, v_mat.Resize({w, 1}), false, T(1), &uv, T(0)); Tensor weight_grad_mat; diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index a14c3a81c121758ed90450cd5eb5990f3f7739e1..eefe4072ed928db9549b5024e2ca6adb76d5500c 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -1375,7 +1375,7 @@ class BatchNorm(layers.Layer): outputs = { "Y": [batch_norm_out], - "MeanOut": [mean_out], + "MeanOut": [], "VarianceOut": [variance_out], "SavedMean": [saved_mean], "SavedVariance": [saved_variance] @@ -3031,9 +3031,11 @@ class SpectralNorm(layers.Layer): dim(int, optional): The index of dimension which should be permuted to the first before reshaping Input(Weight) to matrix, it should be set as 0 if Input(Weight) is the weight of fc layer, and should be set as 1 if Input(Weight) is the weight of conv layer. Default: 0. power_iters(int, optional): The number of power iterations to calculate spectral norm. Default: 1. eps(float, optional): The epsilon for numerical stability in calculating norms. Default: 1e-12. + fix_state(bool, optional): whether to update the two vectors `u` and `v`. Default: True. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` . dtype (str, optional): Data type, it can be "float32" or "float64". Default: "float32". + Returns: None @@ -3055,10 +3057,12 @@ class SpectralNorm(layers.Layer): dim=0, power_iters=1, eps=1e-12, + fix_state=True, dtype='float32'): super(SpectralNorm, self).__init__() self._power_iters = power_iters self._eps = eps + self._fix_state = fix_state self._dim = dim self._dtype = dtype @@ -3080,10 +3084,31 @@ class SpectralNorm(layers.Layer): default_initializer=Normal(0., 1.)) self.weight_v.stop_gradient = True + if fix_state: + self.out_weight_u = self.create_parameter( + attr=ParamAttr(), + shape=[h], + dtype=self._dtype, + default_initializer=Normal(0., 1.)) + self.out_weight_u.stop_gradient = True + + self.out_weight_v = self.create_parameter( + attr=ParamAttr(), + shape=[w], + dtype=self._dtype, + default_initializer=Normal(0., 1.)) + self.out_weight_v.stop_gradient = True + else: + self.out_weight_u = self.weight_u + self.out_weight_v = self.weight_v + def forward(self, weight): check_variable_and_dtype(weight, "weight", ['float32', 'float64'], 'SpectralNorm') - inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v} + inputs = { + 'Weight': weight, 'U': self.weight_u, 'V': self.weight_v, + 'UOut': self.out_weight_u, 'VOut': self.out_weight_v, + } out = self._helper.create_variable_for_type_inference(self._dtype) self._helper.append_op( type="spectral_norm", diff --git a/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py b/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py index 7dd0c7625983ed01ceb9e803b886c45f91840e5d..227ed1d2ff02c836572df7da2ede9f4131d01f0d 100644 --- a/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py @@ -44,7 +44,7 @@ def spectral_norm(weight, u, v, dim, power_iters, eps): u = u / (u_norm + eps) sigma = (u * np.matmul(weight_mat, v)).sum() - return weight / sigma + return weight / sigma, u, v @skip_check_grad_ci( @@ -63,6 +63,7 @@ class TestSpectralNormOpNoGrad(OpTest): "dim": self.dim, "power_iters": self.power_iters, "eps": self.eps, + "fix_state": self.fix_state, } self.inputs = { @@ -71,9 +72,9 @@ class TestSpectralNormOpNoGrad(OpTest): "V": v, } - output = spectral_norm(weight, u, v, self.dim, self.power_iters, - self.eps) - self.outputs = {"Out": output} + output, new_u, new_v = spectral_norm(weight, u, v, self.dim, self.power_iters, + self.eps) + self.outputs = {"Out": output, "UOut": new_u, "VOut": new_v} def test_check_output(self): self.check_output() @@ -85,6 +86,7 @@ class TestSpectralNormOpNoGrad(OpTest): self.dim = 0 self.power_iters = 5 self.eps = 1e-12 + self.fix_state = True @skip_check_grad_ci( @@ -99,6 +101,7 @@ class TestSpectralNormOpNoGrad2(TestSpectralNormOpNoGrad): self.dim = 1 self.power_iters = 10 self.eps = 1e-12 + self.fix_state = True class TestSpectralNormOp(TestSpectralNormOpNoGrad): @@ -115,6 +118,7 @@ class TestSpectralNormOp(TestSpectralNormOpNoGrad): self.dim = 0 self.power_iters = 0 self.eps = 1e-12 + self.fix_state = True class TestSpectralNormOp2(TestSpectralNormOp): @@ -125,6 +129,7 @@ class TestSpectralNormOp2(TestSpectralNormOp): self.dim = 1 self.power_iters = 0 self.eps = 1e-12 + self.fix_state = True class TestSpectralNormOpError(unittest.TestCase):