From 28927209fefd9269eea2dea39839d89e85c3a8d4 Mon Sep 17 00:00:00 2001
From: Wang Xin
Date: Thu, 30 Mar 2023 16:36:48 +0800
Subject: [PATCH] add autogen code support for spectral_norm (#52145)

* add autogen code support for spectral_norm

* bug fixed

* fix PR-CI-Static-Check fail
---
 paddle/fluid/operators/spectral_norm_op.cc | 173 ---------------------
 paddle/phi/api/yaml/backward.yaml          |  10 ++
 paddle/phi/api/yaml/legacy_backward.yaml   |  10 --
 paddle/phi/api/yaml/legacy_ops.yaml        |  10 --
 paddle/phi/api/yaml/op_compat.yaml         |   7 +
 paddle/phi/api/yaml/ops.yaml               |  10 ++
 paddle/phi/ops/compat/spectral_norm_sig.cc |  39 -----
 python/paddle/static/nn/common.py          |  17 +-
 8 files changed, 39 insertions(+), 237 deletions(-)
 delete mode 100644 paddle/fluid/operators/spectral_norm_op.cc
 delete mode 100644 paddle/phi/ops/compat/spectral_norm_sig.cc

diff --git a/paddle/fluid/operators/spectral_norm_op.cc b/paddle/fluid/operators/spectral_norm_op.cc
deleted file mode 100644
index 85bd8676652..00000000000
--- a/paddle/fluid/operators/spectral_norm_op.cc
+++ /dev/null
@@ -1,173 +0,0 @@
-/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-   http://www.apache.org/licenses/LICENSE-2.0
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
-
-#include <memory>
-
-#include "paddle/fluid/framework/infershape_utils.h"
-#include "paddle/fluid/framework/op_registry.h"
-
-#include "paddle/phi/infermeta/backward.h"
-#include "paddle/phi/infermeta/ternary.h"
-
-namespace paddle {
-namespace operators {
-
-class SpectralNormOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    return phi::KernelKey(
-        OperatorWithKernel::IndicateVarDataType(ctx, "Weight"), ctx.GetPlace());
-  }
-};
-
-class SpectralNormOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput("Weight",
-             "The input weight tensor of spectral_norm operator, "
-             "This can be a 2-D, 3-D, 4-D, 5-D tensor which is the "
-             "weights of fc, conv1d, conv2d, conv3d layer. "
-             "The data type is float32 or float64.");
-    AddInput("U",
-             "The weight_u tensor of spectral_norm operator, "
-             "This can be a 1-D tensor in shape [H, 1],"
-             "H is the 1st dimensions of Weight after reshape"
-             "corresponding by Attr(dim). As for Attr(dim) = 1"
-             "in conv2d layer with weight shape [M, C, K1, K2]"
-             "Weight will be reshape to [C, M*K1*K2], U will"
-             "be in shape [C, 1].");
-    AddInput("V",
-             "The weight_v tensor of spectral_norm operator, "
-             "This can be a 1-D tensor in shape [W, 1], "
-             "W is the 2nd dimensions of Weight after reshape "
-             "corresponding by Attr(dim). As for Attr(dim) = 1 "
-             "in conv2d layer with weight shape [M, C, K1, K2] "
-             "Weight will be reshape to [C, M*K1*K2], V will "
-             "be in shape [M*K1*K2, 1].");
-    AddOutput("Out",
-              "The output weight tensor of spectral_norm operator, "
-              "This tensor is in same shape with Input(Weight).");
-
-    AddAttr<int>("dim",
-                 "The index of dimension which should be permuted "
-                 "to the first before reshaping Input(Weight) to "
-                 "matrix, it should be set as 0 if Input(Weight) is "
-                 "the weight of fc layer, and should be set as 1 if "
-                 "Input(Weight) is the weight of conv layer, "
-                 "default 0.")
-        .SetDefault(0);
-    AddAttr<int>("power_iters",
-                 "number of power iterations to calculate "
-                 "spectral norm, default 1.")
-        .SetDefault(1);
-    AddAttr<float>("eps",
-                   "epsilon for numerical stability in "
-                   "calculating norms, it will be added to "
-                   "the denominator to aviod divide zero. "
-                   "Default 1e-12.")
-        .SetDefault(1e-12);
-
-    AddComment(R"DOC(
-          This layer calculates the spectral normalization value of weight of
-          fc, conv1d, conv2d, conv3d layers which should be 2-D, 3-D, 4-D, 5-D
-          tensor.
-
-          Spectral normalization stabilizes the training of critic in GANs
-          (Generative Adversarial Networks). This layer rescaling weight tensor
-          with spectral normalize value.
-
-          For spectral normalization calculations, we rescaling weight
-          tensor with :math:`\sigma`, while :math:`\sigma{\mathbf{W}}` is
-
-            $$\sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \\frac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}$$
-
-          We calculate :math:`\sigma{\mathbf{W}}` through power iterations as
-
-            $$
-            \mathbf{v} = \mathbf{W}^{T} \mathbf{u}
-            $$
-            $$
-            \mathbf{v} = \\frac{\mathbf{v}}{\|\mathbf{v}\|_2}
-            $$
-            $$
-            \mathbf{u} = \mathbf{W}^{T} \mathbf{v}
-            $$
-            $$
-            \mathbf{u} = \\frac{\mathbf{u}}{\|\mathbf{u}\|_2}
-            $$
-
-          And :math:`\sigma` should be
-
-            $$\sigma{\mathbf{W}} = \mathbf{u}^{T} \mathbf{W} \mathbf{v}$$
-
-          For details of spectral normalization, please refer to paper:
-          `Spectral Normalization <https://arxiv.org/abs/1802.05957>`_ .
- )DOC"); - } -}; - -template -class SpectralNormGradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr op) const override { - op->SetType("spectral_norm_grad"); - - op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - op->SetInput("Weight", this->Input("Weight")); - op->SetInput("U", this->Input("U")); - op->SetInput("V", this->Input("V")); - - op->SetOutput(framework::GradVarName("Weight"), this->InputGrad("Weight")); - - op->SetAttrMap(this->Attrs()); - } -}; - -class SpectralNormOpGrad : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return phi::KernelKey( - OperatorWithKernel::IndicateVarDataType(ctx, "Weight"), ctx.GetPlace()); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -DECLARE_INFER_SHAPE_FUNCTOR(spectral_norm, - SpectralNormInferMetaFunctor, - PD_INFER_META(phi::SpectralNormInferMeta)); -DECLARE_INFER_SHAPE_FUNCTOR(spectral_norm_grad, - SpectralNormGradInferMetaFunctor, - PD_INFER_META(phi::SpectralNormGradInferMeta)); - -REGISTER_OPERATOR(spectral_norm, - ops::SpectralNormOp, - ops::SpectralNormOpMaker, - ops::SpectralNormGradOpMaker, - ops::SpectralNormGradOpMaker, - SpectralNormInferMetaFunctor); -REGISTER_OPERATOR(spectral_norm_grad, - ops::SpectralNormOpGrad, - SpectralNormGradInferMetaFunctor); diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 92beb701e5d..858b2d15609 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -1557,6 +1557,16 @@ kernel : func : solve_grad +- backward_op : spectral_norm_grad + forward : spectral_norm (Tensor weight, Tensor u, Tensor v, int dim = 0, int power_iters = 1, float eps=1e-12f) -> Tensor(out) + args : (Tensor weight, Tensor u, Tensor v, Tensor out_grad, int dim, int power_iters, float eps) + output : Tensor(weight_grad) + infer_meta : + func : SpectralNormGradInferMeta + kernel : + func : spectral_norm_grad + data_type : weight + - backward_op : sqrt_double_grad forward : sqrt_grad (Tensor out, Tensor grad_out) -> Tensor(grad_x) args : (Tensor out, Tensor grad_x, Tensor grad_x_grad) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 035d301589d..cfe7930c8c3 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -1044,16 +1044,6 @@ func : softmax_grad composite : softmax_grad(out, out_grad, axis, x_grad) -- backward_op : spectral_norm_grad - forward : spectral_norm (Tensor weight, Tensor u, Tensor v, int dim, int power_iters, float eps) -> Tensor(out) - args : (Tensor weight, Tensor u, Tensor v, Tensor out_grad, int dim, int power_iters, float eps) - output : Tensor(weight_grad) - infer_meta : - func : SpectralNormGradInferMeta - kernel : - func : spectral_norm_grad - data_type : out_grad - - backward_op : split_grad forward : split (Tensor x, IntArray num_or_sections, Scalar axis) -> Tensor[](out) args : (Tensor[] out_grad, Scalar axis = -1) diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 6cf0d1640fc..8fa2243f30e 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -1400,16 +1400,6 @@ inplace : (x -> out) backward : softmax_grad -- op 
-  args : (Tensor weight, Tensor u, Tensor v, int dim, int power_iters, float eps)
-  output : Tensor
-  infer_meta :
-    func : SpectralNormInferMeta
-  kernel :
-    func : spectral_norm
-    data_type : weight
-  backward : spectral_norm_grad
-
 - op : split
   args : (Tensor x, IntArray sections, Scalar(int) axis)
   output : Tensor[]{sections.size()}
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index 1111d14351a..15ca92c78b4 100644
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -1795,6 +1795,13 @@
   outputs :
     out : Out

+- op : spectral_norm
+  backward : spectral_norm_grad
+  inputs :
+    {weight : Weight, u : U, v : V}
+  outputs :
+    out : Out
+
 - op : sqrt
   backward : sqrt_grad, sqrt_double_grad (sqrt_grad_grad)
   inputs :
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index f80a0a770fc..31329287d58 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -1499,6 +1499,16 @@
     data_type : x
   backward : solve_grad

+- op : spectral_norm
+  args : (Tensor weight, Tensor u, Tensor v, int dim = 0, int power_iters = 1, float eps = 1e-12f)
+  output : Tensor
+  infer_meta :
+    func : SpectralNormInferMeta
+  kernel :
+    func : spectral_norm
+    data_type : weight
+  backward : spectral_norm_grad
+
 - op : sqrt
   args : (Tensor x)
   output : Tensor(out)
diff --git a/paddle/phi/ops/compat/spectral_norm_sig.cc b/paddle/phi/ops/compat/spectral_norm_sig.cc
deleted file mode 100644
index ea11df24881..00000000000
--- a/paddle/phi/ops/compat/spectral_norm_sig.cc
+++ /dev/null
@@ -1,39 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/phi/core/compat/op_utils.h"
-
-namespace phi {
-
-KernelSignature SpectralNormOpArgumentMapping(
-    const ArgumentMappingContext& ctx) {
-  return KernelSignature("spectral_norm",
-                         {"Weight", "U", "V"},
-                         {"dim", "power_iters", "eps"},
-                         {"Out"});
-}
-
-KernelSignature SpectralNormGradOpArgumentMapping(
-    const ArgumentMappingContext& ctx) {
-  return KernelSignature("spectral_norm_grad",
-                         {"Weight", "U", "V", "Out@GRAD"},
-                         {"dim", "power_iters", "eps"},
-                         {"Weight@GRAD"});
-}
-
-}  // namespace phi
-
-PD_REGISTER_ARG_MAPPING_FN(spectral_norm, phi::SpectralNormOpArgumentMapping);
-PD_REGISTER_ARG_MAPPING_FN(spectral_norm_grad,
-                           phi::SpectralNormGradOpArgumentMapping);
diff --git a/python/paddle/static/nn/common.py b/python/paddle/static/nn/common.py
index 05e76601fea..37fe41624a4 100644
--- a/python/paddle/static/nn/common.py
+++ b/python/paddle/static/nn/common.py
@@ -3376,7 +3376,6 @@ def row_conv(input, future_context_size, param_attr=None, act=None):
     return helper.append_activation(out)


-@templatedoc()
 def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None):
     r"""
     :api_attr: Static Graph
@@ -3417,10 +3416,18 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None):
     Refer to `Spectral Normalization <https://arxiv.org/abs/1802.05957>`_ .

     Args:
-        weight(Tensor): ${weight_comment}
-        dim(int): ${dim_comment}
-        power_iters(int): ${power_iters_comment}
-        eps(float): ${eps_comment}
+        weight(Tensor): The input weight tensor of the spectral_norm operator.
+                        This can be a 2-D, 3-D, 4-D or 5-D tensor, which is the
+                        weight of an fc, conv1d, conv2d or conv3d layer.
+                        The data type is float32 or float64.
+        dim(int): The index of the dimension that should be permuted to the
+                  front before reshaping Input(Weight) into a matrix. Set it
+                  to 0 if Input(Weight) is the weight of an fc layer, and to
+                  1 if Input(Weight) is the weight of a conv layer.
+                  Default 0.
+        power_iters(int): Number of power iterations used to calculate the spectral norm. Default 1.
+        eps(float): Epsilon added to the denominator for numerical stability when
+                    calculating norms, to avoid division by zero. Default 1e-12.
         name(str, optional): For detailed information, please refer to
              :ref:`api_guide_Name`. Usually name is no need to set and
              None by default.
--
GitLab
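
Reviewer note: for readers following the deleted operator comment, the power iteration it documents can be sketched in a few lines of NumPy. This is only an illustrative sketch under assumptions made here (the function name, the use of NumPy, and treating u and v as 1-D vectors are not part of the patch or of the phi kernel). One caveat: the comment's third update is written as u = W^T v, which does not match the documented shapes (u has H entries, v has W entries after Weight is reshaped to an H x W matrix); the dimensionally consistent update, used below, is u = W v.

    import numpy as np

    def spectral_norm_sketch(weight, u, v, dim=0, power_iters=1, eps=1e-12):
        """Illustrative spectral normalization via power iteration (not the phi kernel)."""
        # Permute `dim` to the front and flatten the rest, as the op doc describes
        # (e.g. a conv2d weight [M, C, K1, K2] with dim=1 becomes [C, M*K1*K2]).
        w = np.moveaxis(weight, dim, 0).reshape(weight.shape[dim], -1)
        for _ in range(power_iters):
            v = w.T @ u
            v = v / (np.linalg.norm(v) + eps)   # v := W^T u / ||W^T u||
            u = w @ v
            u = u / (np.linalg.norm(u) + eps)   # u := W v / ||W v||
        sigma = u @ w @ v                       # sigma(W) = u^T W v
        return weight / sigma                   # rescale the weight by its spectral norm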
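Likewise, a minimal static-graph usage sketch of the paddle.static.nn.spectral_norm API whose docstring this patch fills in; the tensor name and the [2, 8, 32, 32] shape are assumptions chosen for illustration only.

    import paddle

    paddle.enable_static()

    # A conv2d-style weight; dim=1 matches the conv case described in the docstring.
    weight = paddle.static.data(name='weight', shape=[2, 8, 32, 32], dtype='float32')
    out = paddle.static.nn.spectral_norm(weight=weight, dim=1, power_iters=2)
    print(out.shape)  # same shape as the input weight: [2, 8, 32, 32]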