spectral_norm_op.cc 6.6 KB
Newer Older
1
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

Z
zhhsplendid 已提交
12 13
#include <memory>

14
#include "paddle/fluid/framework/infershape_utils.h"
D
dengkaipeng 已提交
15 16
#include "paddle/fluid/framework/op_registry.h"

17 18 19
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/ternary.h"

D
dengkaipeng 已提交
20 21 22 23 24 25 26 27 28 29 30 31
namespace paddle {
namespace operators {

// Brings framework::Tensor into paddle::operators.
// NOTE(review): nothing in the visible part of this file refers to Tensor;
// confirm whether the alias is still needed before removing it.
using framework::Tensor;

// Forward spectral_norm operator. It only customizes kernel selection; the
// constructors are inherited from framework::OperatorWithKernel.
class SpectralNormOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Chooses the kernel from the data type of Input(Weight) and the place
  // reported by the execution context.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    const auto weight_dtype =
        OperatorWithKernel::IndicateVarDataType(ctx, "Weight");
    return framework::OpKernelType(weight_dtype, ctx.GetPlace());
  }
};

// Declares the proto of the spectral_norm operator: its inputs (Weight, U,
// V), its output (Out), its attributes (dim, power_iters, eps) and the
// operator-level documentation.
class SpectralNormOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Registers every input, output, attribute and the op comment.
  //
  // The descriptions below are built from adjacent string literals, which the
  // compiler concatenates verbatim. Every piece except the last therefore
  // has to end with a space, otherwise words run together in the emitted
  // documentation.
  void Make() override {
    AddInput("Weight",
             "The input weight tensor of spectral_norm operator, "
             "This can be a 2-D, 3-D, 4-D, 5-D tensor which is the "
             "weights of fc, conv1d, conv2d, conv3d layer. "
             "The data type is float32 or float64.");
    AddInput("U",
             "The weight_u tensor of spectral_norm operator, "
             "This can be a 1-D tensor in shape [H, 1], "
             "H is the 1st dimensions of Weight after reshape "
             "corresponding by Attr(dim). As for Attr(dim) = 1 "
             "in conv2d layer with weight shape [M, C, K1, K2] "
             "Weight will be reshape to [C, M*K1*K2], U will "
             "be in shape [C, 1].");
    AddInput("V",
             "The weight_v tensor of spectral_norm operator, "
             "This can be a 1-D tensor in shape [W, 1], "
             "W is the 2nd dimensions of Weight after reshape "
             "corresponding by Attr(dim). As for Attr(dim) = 1 "
             "in conv2d layer with weight shape [M, C, K1, K2] "
             "Weight will be reshape to [C, M*K1*K2], V will "
             "be in shape [M*K1*K2, 1].");
    AddOutput("Out",
              "The output weight tensor of spectral_norm operator, "
              "This tensor is in same shape with Input(Weight).");

    AddAttr<int>("dim",
                 "The index of dimension which should be permuted "
                 "to the first before reshaping Input(Weight) to "
                 "matrix, it should be set as 0 if Input(Weight) is "
                 "the weight of fc layer, and should be set as 1 if "
                 "Input(Weight) is the weight of conv layer, "
                 "default 0.")
        .SetDefault(0);
    AddAttr<int>("power_iters",
                 "number of power iterations to calculate "
                 "spectral norm, default 1.")
        .SetDefault(1);
    AddAttr<float>("eps",
                   "epsilon for numerical stability in "
                   "calculating norms, it will be added to "
                   "the denominator to avoid divide zero. "
                   "Default 1e-12.")
        .SetDefault(1e-12);

    // In the power iteration below, v is updated from W^T u and u from W v.
    // With W reshaped to [H, W], u has H entries and v has W entries, so
    // these are the only dimensionally consistent products and they match
    // the final sigma = u^T W v.
    AddComment(R"DOC(
          This layer calculates the spectral normalization value of weight of
          fc, conv1d, conv2d, conv3d layers which should be 2-D, 3-D, 4-D, 5-D
          tensor.

          Spectral normalization stabilizes the training of critic in GANs
          (Generative Adversarial Networks). This layer rescaling weight tensor
          with spectral normalize value.

          For spectral normalization calculations, we rescaling weight
          tensor with :math:`\sigma`, while :math:`\sigma{\mathbf{W}}` is

            $$\sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \\frac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}$$

          We calculate :math:`\sigma{\mathbf{W}}` through power iterations as

            $$
            \mathbf{v} = \mathbf{W}^{T} \mathbf{u}
            $$
            $$
            \mathbf{v} = \\frac{\mathbf{v}}{\|\mathbf{v}\|_2}
            $$
            $$
            \mathbf{u} = \mathbf{W} \mathbf{v}
            $$
            $$
            \mathbf{u} = \\frac{\mathbf{u}}{\|\mathbf{u}\|_2}
            $$

          And :math:`\sigma` should be

            $$\sigma{\mathbf{W}} = \mathbf{u}^{T} \mathbf{W} \mathbf{v}$$

          For details of spectral normalization, please refer to paper: 
          `Spectral Normalization <https://arxiv.org/abs/1802.05957>`_ .
         )DOC");
  }
};

H
hong 已提交
123 124
// Describes the spectral_norm_grad op that corresponds to a forward
// spectral_norm op. T is the op representation supplied by the framework;
// the registration in this file uses framework::OpDesc and
// imperative::OpBase.
template <typename T>
class SpectralNormGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  // Fills in `op`: the gradient of Out plus the forward inputs Weight, U and
  // V are its inputs, the gradient of Weight is its only output, and the
  // forward attributes are copied over unchanged.
  void Apply(GradOpPtr<T> op) const override {
    const auto out_grad_name = framework::GradVarName("Out");
    const auto weight_grad_name = framework::GradVarName("Weight");

    op->SetType("spectral_norm_grad");

    // Inputs of the backward op.
    op->SetInput(out_grad_name, this->OutputGrad("Out"));
    op->SetInput("Weight", this->Input("Weight"));
    op->SetInput("U", this->Input("U"));
    op->SetInput("V", this->Input("V"));

    // Only Weight receives a gradient.
    op->SetOutput(weight_grad_name, this->InputGrad("Weight"));

    op->SetAttrMap(this->Attrs());
  }
};

D
dengkaipeng 已提交
143 144 145 146 147 148 149
// Backward (gradient) operator of spectral_norm. As with the forward op,
// only kernel selection is customized here.
class SpectralNormOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // The kernel is keyed on the data type of Input(Weight) and on the place
  // taken from the execution context.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    const auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "Weight");
    const auto place = ctx.GetPlace();
    return framework::OpKernelType(dtype, place);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
159 160 161 162 163 164 165 166

// Declares SpectralNormInferMetaFunctor for the spectral_norm op, backed by
// phi::SpectralNormInferMeta. The functor name is passed to
// REGISTER_OPERATOR(spectral_norm, ...) further down in this file.
DECLARE_INFER_SHAPE_FUNCTOR(spectral_norm,
                            SpectralNormInferMetaFunctor,
                            PD_INFER_META(phi::SpectralNormInferMeta));
// Same for the backward op: SpectralNormGradInferMetaFunctor is backed by
// phi::SpectralNormGradInferMeta and is passed to
// REGISTER_OPERATOR(spectral_norm_grad, ...).
DECLARE_INFER_SHAPE_FUNCTOR(spectral_norm_grad,
                            SpectralNormGradInferMetaFunctor,
                            PD_INFER_META(phi::SpectralNormGradInferMeta));

167 168 169
// Forward op registration: operator class, proto maker, the grad-op maker
// instantiated for both framework::OpDesc and imperative::OpBase, and the
// infer-shape functor declared above.
REGISTER_OPERATOR(
    spectral_norm,
    ops::SpectralNormOp,
    ops::SpectralNormOpMaker,
    ops::SpectralNormGradOpMaker<paddle::framework::OpDesc>,
    ops::SpectralNormGradOpMaker<paddle::imperative::OpBase>,
    SpectralNormInferMetaFunctor);

// Backward op registration: operator class plus its infer-shape functor.
REGISTER_OPERATOR(
    spectral_norm_grad,
    ops::SpectralNormOpGrad,
    SpectralNormGradInferMetaFunctor);