diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 52af3ce51ba67c2b58c5e79c18c8d554e3c4b68c..afbff1e13cf7dc4a8f9f4bd9252e2194e0862c93 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -128,6 +128,7 @@ paddle.fluid.layers.row_conv (ArgSpec(args=['input', 'future_context_size', 'par
 paddle.fluid.layers.multiplex (ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None), ('document', '013795af319e2e86d3506741941078ee'))
 paddle.fluid.layers.layer_norm (ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None)), ('document', 'de6a906950bae9f3c245cb744d22b94e'))
 paddle.fluid.layers.group_norm (ArgSpec(args=['input', 'groups', 'epsilon', 'param_attr', 'bias_attr', 'act', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None, 'NCHW', None)), ('document', '419c3a24a83cc89219a029cf4092788b'))
+paddle.fluid.layers.spectral_norm (ArgSpec(args=['weight', 'dim', 'power_iters', 'eps', 'name'], varargs=None, keywords=None, defaults=(0, 1, 1e-12, None)), ('document', '3f536aafba30d793287b52d231baff1b'))
 paddle.fluid.layers.softmax_with_cross_entropy (ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index', 'numeric_stable_mode', 'return_softmax'], varargs=None, keywords=None, defaults=(False, -100, True, False)), ('document', 'bce1b75e3d95b75cacd1099655cbb3c3'))
 paddle.fluid.layers.smooth_l1 (ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', 'c6b175d253c55baf4b9c0eca9b1dda88'))
 paddle.fluid.layers.one_hot (ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None), ('document', '6148b6a555cbfb62fdcd030d8982c18c'))
diff --git a/paddle/fluid/operators/spectral_norm_op.cc b/paddle/fluid/operators/spectral_norm_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..357d055756523cd83bf0e4b30719155b32c65974
--- /dev/null
+++ b/paddle/fluid/operators/spectral_norm_op.cc
@@ -0,0 +1,197 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+       http://www.apache.org/licenses/LICENSE-2.0
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/fluid/operators/spectral_norm_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+class SpectralNormOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Weight"),
+                   "Input(Weight) of SpectralNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("U"),
+                   "Input(U) of SpectralNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("V"),
+                   "Input(V) of SpectralNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of SpectralNormOp should not be null.");
+
+    auto dim_weight = ctx->GetInputDim("Weight");
+    auto rank_weight = dim_weight.size();
+    PADDLE_ENFORCE(rank_weight >= 2 && rank_weight <= 5,
+                   "The rank of Input(Weight) can only be 2, 3, "
+                   "4, 5 for fc, conv1d, conv2d, conv3d layers.");
+
+    int dim = ctx->Attrs().Get<int>("dim");
+    int power_iters = ctx->Attrs().Get<int>("power_iters");
+    PADDLE_ENFORCE(dim == 0 || dim == 1, "Attr(dim) can only be 0 or 1");
+    PADDLE_ENFORCE(power_iters >= 0,
+                   "Attr(power_iters) should be greater than or equal to 0");
+
+    int h = dim_weight[dim];
+    int w = 1;
+    for (int i = 0; i < rank_weight; i++) {
+      if (i != dim) {
+        w *= dim_weight[i];
+      }
+    }
+    auto dim_u = ctx->GetInputDim("U");
+    auto dim_v = ctx->GetInputDim("V");
+    PADDLE_ENFORCE_EQ(dim_u[0], h,
+                      "Input(U) dims[0] should be equal to "
+                      "Input(Weight) dims[Attr(dim)]");
+    PADDLE_ENFORCE_EQ(
+        dim_v[0], w,
+        "Input(V) dims[0] should be equal to "
+        "the product of Input(Weight) dims except dims[Attr(dim)]");
+
+    ctx->SetOutputDim("Out", dim_weight);
+    ctx->ShareLoD("Weight", /*->*/ "Out");
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(ctx.Input<Tensor>("Weight")->type(),
+                                   ctx.GetPlace());
+  }
+};
+
+class SpectralNormOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Weight",
+             "The input weight tensor of spectral_norm operator. "
+             "This can be a 2-D, 3-D, 4-D or 5-D tensor which is the "
+             "weight of an fc, conv1d, conv2d or conv3d layer.");
+    AddInput("U",
+             "The weight_u tensor of spectral_norm operator. "
+             "This is a 1-D tensor with shape [H], where H is the 1st "
+             "dimension of Weight after reshaping corresponding to "
+             "Attr(dim). For Attr(dim) = 1 in a conv2d layer with "
+             "weight shape [M, C, K1, K2], Weight will be reshaped "
+             "to [C, M*K1*K2] and U will be in shape [C].");
+    AddInput("V",
+             "The weight_v tensor of spectral_norm operator. "
+             "This is a 1-D tensor with shape [W], where W is the 2nd "
+             "dimension of Weight after reshaping corresponding to "
+             "Attr(dim). For Attr(dim) = 1 in a conv2d layer with "
+             "weight shape [M, C, K1, K2], Weight will be reshaped "
+             "to [C, M*K1*K2] and V will be in shape [M*K1*K2].");
+    AddOutput("Out",
+              "The output weight tensor of spectral_norm operator. "
+              "This tensor has the same shape as Input(Weight).");
+
+    AddAttr<int>("dim",
+                 "The index of the dimension which should be permuted "
+                 "to the first before reshaping Input(Weight) to a "
+                 "matrix. It should be set to 0 if Input(Weight) is "
+                 "the weight of an fc layer, and set to 1 if "
+                 "Input(Weight) is the weight of a conv layer. "
+                 "Default 0.")
+        .SetDefault(0);
+    AddAttr<int>("power_iters",
+                 "The number of power iterations used to calculate the "
+                 "spectral norm. Default 1.")
+        .SetDefault(1);
+    AddAttr<float>("eps",
+                   "The epsilon for numerical stability in "
+                   "calculating norms. Default 1e-12.")
+        .SetDefault(1e-12);
+
+    AddComment(R"DOC(
+          This operator calculates the spectral normalization value of the weight
+          of fc, conv1d, conv2d, conv3d layers, which should be a 2-D, 3-D, 4-D
+          or 5-D tensor.
+
+          Spectral normalization stabilizes the training of the critic in GANs
+          (Generative Adversarial Networks). This operator rescales the weight
+          tensor with the spectral norm value.
+
+          For spectral normalization, we rescale the weight tensor
+          with :math:`\sigma`, where :math:`\sigma(\mathbf{W})` is
+
+            $$\sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \frac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}$$
+
+          We estimate :math:`\sigma(\mathbf{W})` through power iterations:
+
+            $$
+            \mathbf{v} := \frac{\mathbf{W}^{T} \mathbf{u}}{\|\mathbf{W}^{T} \mathbf{u}\|_2}
+            $$
+            $$
+            \mathbf{u} := \frac{\mathbf{W} \mathbf{v}}{\|\mathbf{W} \mathbf{v}\|_2}
+            $$
+
+          And then :math:`\sigma` is estimated as
+
+            $$\sigma(\mathbf{W}) = \mathbf{u}^{T} \mathbf{W} \mathbf{v}$$
+
+          For details of spectral normalization, please refer to the paper:
+          `Spectral Normalization <https://arxiv.org/abs/1802.05957>`_ .
+      )DOC");
+  }
+};
+
+class SpectralNormOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Weight"),
+                   "Input(Weight) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput("U"), "Input(U) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput("V"), "Input(V) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) should not be null");
+    auto dim_x = ctx->GetInputDim("Weight");
+    if (ctx->HasOutput(framework::GradVarName("Weight"))) {
+      ctx->SetOutputDim(framework::GradVarName("Weight"), dim_x);
+    }
+  }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(ctx.Input<Tensor>("Weight")->type(),
+                                   ctx.GetPlace());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(spectral_norm, ops::SpectralNormOp, ops::SpectralNormOpMaker,
+                  paddle::framework::DefaultGradOpDescMaker<true>);
+REGISTER_OPERATOR(spectral_norm_grad, ops::SpectralNormOpGrad);
+REGISTER_OP_CPU_KERNEL(
+    spectral_norm,
+    ops::SpectralNormKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SpectralNormKernel<paddle::platform::CPUDeviceContext, double>);
+REGISTER_OP_CPU_KERNEL(
+    spectral_norm_grad,
+    ops::SpectralNormGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SpectralNormGradKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/spectral_norm_op.cu b/paddle/fluid/operators/spectral_norm_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ea90e3b4c122b00d5bfe13617e48a9bbe0ee8395
--- /dev/null
+++ b/paddle/fluid/operators/spectral_norm_op.cu
@@ -0,0 +1,22 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+       http://www.apache.org/licenses/LICENSE-2.0
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/fluid/operators/spectral_norm_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    spectral_norm,
+    ops::SpectralNormKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SpectralNormKernel<paddle::platform::CUDADeviceContext, double>);
+REGISTER_OP_CUDA_KERNEL(
+    spectral_norm_grad,
+    ops::SpectralNormGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SpectralNormGradKernel<paddle::platform::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/spectral_norm_op.h b/paddle/fluid/operators/spectral_norm_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..eb48e3b7840e18efe809540dd697f243a0a63a52
--- /dev/null
+++ b/paddle/fluid/operators/spectral_norm_op.h
@@ -0,0 +1,273 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+       http://www.apache.org/licenses/LICENSE-2.0
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+#include <vector>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T, size_t D, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
+using Tensor = framework::Tensor;
+
+using Array1 = Eigen::DSizes<int64_t, 1>;
+using Array2 = Eigen::DSizes<int64_t, 2>;
+using IndexPair = Eigen::IndexPair<int>;
+
+template <typename DeviceContext, typename T>
+static inline void TransCompute(const int rank, const Tensor& in, Tensor* out,
+                                const std::vector<int>& perm,
+                                const DeviceContext& dev_ctx) {
+  if (rank <= 1 || rank > 5) {
+    PADDLE_THROW("Invalid weight rank.");
+  }
+
+  switch (rank) {
+    case 2:
+      math::Transpose<DeviceContext, T, 2> trans2;
+      trans2(dev_ctx, in, out, perm);
+      break;
+    case 3:
+      math::Transpose<DeviceContext, T, 3> trans3;
+      trans3(dev_ctx, in, out, perm);
+      break;
+    case 4:
+      math::Transpose<DeviceContext, T, 4> trans4;
+      trans4(dev_ctx, in, out, perm);
+      break;
+    case 5:
+      math::Transpose<DeviceContext, T, 5> trans5;
+      trans5(dev_ctx, in, out, perm);
+      break;
+    default:
+      break;
+  }
+}
+
+template <typename DeviceContext, typename T>
+static inline void CalcMatrixSigmaAndNormWeight(
+    Tensor* sigma, Tensor* u, Tensor* v, Tensor* weight, const int power_iters,
+    const float eps, const framework::ExecutionContext& ctx) {
+  auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
+  auto blas = math::GetBlas<DeviceContext, T>(ctx);
+  auto sigma_t = EigenTensor<T, 2>::From(*sigma);
+  auto weight_t = EigenTensor<T, 2>::From(*weight);
+  auto u_t = EigenTensor<T, 2>::From(*u);
+  auto v_t = EigenTensor<T, 2>::From(*v);
+
+  const int h = weight->dims()[0];
+  const int w = weight->dims()[1];
+
+  for (int i = 0; i < power_iters; i++) {
+    // V = W^T * U / ||W^T * U||_2
+    blas.MatMul(*weight, true, *u, false, T(1), v, T(0));
+    auto v_t_norm =
+        v_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast(
+            Array1(w));
+    v_t.device(place) = v_t / (v_t_norm + v_t_norm.constant(eps));
+    // U = W * V / ||W * V||_2
+    blas.MatMul(*weight, false, *v, false, T(1), u, T(0));
+    auto u_t_norm =
+        u_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast(
+            Array1(h));
+    u_t.device(place) = u_t / (u_t_norm + u_t_norm.constant(eps));
+  }
+  Tensor weight_v;
+  weight_v.mutable_data<T>({h, 1}, ctx.GetPlace());
+  blas.MatMul(*weight, false, *v, false, T(1), &weight_v, T(0));
+  auto weight_v_t = EigenTensor<T, 2>::From(weight_v);
+  // sigma = U^T * W * V, broadcast to [h, w] for the element-wise division
+  sigma_t.device(place) = (u_t * weight_v_t)
+                              .sum()
+                              .eval()
+                              .reshape(Array2(1, 1))
+                              .broadcast(Array2(h, w));
+  weight_t.device(place) = weight_t / sigma_t;
+}
+
+template <typename DeviceContext, typename T>
+class SpectralNormKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    auto weight = ctx.Input<Tensor>("Weight");
+    auto u = ctx.Input<Tensor>("U");
+    auto v = ctx.Input<Tensor>("V");
+    auto out = ctx.Output<Tensor>("Out");
+
+    int dim = ctx.Attr<int>("dim");
+    int power_iters = ctx.Attr<int>("power_iters");
+    float eps = ctx.Attr<float>("eps");
+
+    const int h = u->dims()[0];
+    const int w = v->dims()[0];
+
+    Tensor weight_mat;
+    auto dims = weight->dims();
+    const int rank = dims.size();
+    std::vector<int64_t> real_dims;
+    if (dim != 0) {
+      // permute Attr(dim) to the front, then flatten into a [h, w] matrix
+      std::vector<int> perm;
+      perm.push_back(dim);
+      real_dims.push_back(dims[dim]);
+      for (int i = 0; i < rank; i++) {
+        if (i != dim) {
+          perm.push_back(i);
+          real_dims.push_back(dims[i]);
+        }
+      }
+      weight_mat.mutable_data<T>(framework::make_ddim(real_dims),
+                                 ctx.GetPlace());
+      TransCompute<DeviceContext, T>(rank, *weight, &weight_mat, perm,
+                                     dev_ctx);
+    } else {
+      for (int i = 0; i < rank; i++) {
+        real_dims.push_back(dims[i]);
+      }
+      TensorCopySync(*weight, ctx.GetPlace(), &weight_mat);
+    }
+    weight_mat = weight_mat.Resize({h, w});
+
+    Tensor sigma;
+    sigma.mutable_data<T>(weight_mat.dims(), ctx.GetPlace());
+    Tensor uu, vv;
+    TensorCopySync(*u, ctx.GetPlace(), &uu);
+    TensorCopySync(*v, ctx.GetPlace(), &vv);
+    CalcMatrixSigmaAndNormWeight<DeviceContext, T>(
+        &sigma, &(uu.Resize({h, 1})), &(vv.Resize({w, 1})), &weight_mat,
+        power_iters, eps, ctx);
+
+    if (dim != 0) {
+      // permute the first dimension back to position Attr(dim)
+      std::vector<int> perm;
+      for (int i = 0; i < rank; i++) {
+        if (i < dim) {
+          perm.push_back(i + 1);
+        } else if (i == dim) {
+          perm.push_back(0);
+        } else {
+          perm.push_back(i);
+        }
+      }
+      out->mutable_data<T>(dims, ctx.GetPlace());
+      TransCompute<DeviceContext, T>(
+          rank, weight_mat.Resize(framework::make_ddim(real_dims)), out, perm,
+          dev_ctx);
+    } else {
+      TensorCopySync(weight_mat.Resize(dims), ctx.GetPlace(), out);
+    }
+  }
+};
+
+template <typename DeviceContext, typename T>
+class SpectralNormGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    auto blas = math::GetBlas<DeviceContext, T>(ctx);
+    auto weight = ctx.Input<Tensor>("Weight");
+    auto u = ctx.Input<Tensor>("U");
+    auto v = ctx.Input<Tensor>("V");
+    auto out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto weight_grad = ctx.Output<Tensor>(framework::GradVarName("Weight"));
+
+    int dim = ctx.Attr<int>("dim");
+    int power_iters = ctx.Attr<int>("power_iters");
+    float eps = ctx.Attr<float>("eps");
+
+    const int h = u->dims()[0];
+    const int w = v->dims()[0];
+
+    Tensor weight_mat, out_grad_mat;
+    auto dims = weight->dims();
+    const int rank = dims.size();
+    std::vector<int64_t> real_dims;
+    if (dim != 0) {
+      std::vector<int> perm;
+      perm.push_back(dim);
+      real_dims.push_back(dims[dim]);
+      for (int i = 0; i < rank; i++) {
+        if (i != dim) {
+          perm.push_back(i);
+          real_dims.push_back(dims[i]);
+        }
+      }
+      weight_mat.mutable_data<T>(framework::make_ddim(real_dims),
+                                 ctx.GetPlace());
+      out_grad_mat.mutable_data<T>(framework::make_ddim(real_dims),
+                                   ctx.GetPlace());
+      TransCompute<DeviceContext, T>(rank, *weight, &weight_mat, perm,
+                                     dev_ctx);
+      TransCompute<DeviceContext, T>(rank, *out_grad, &out_grad_mat, perm,
+                                     dev_ctx);
+    } else {
+      for (int i = 0; i < rank; i++) {
+        real_dims.push_back(dims[i]);
+      }
+      TensorCopySync(*weight, ctx.GetPlace(), &weight_mat);
+      TensorCopySync(*out_grad, ctx.GetPlace(), &out_grad_mat);
+    }
+    weight_mat = weight_mat.Resize({h, w});
+    out_grad_mat = out_grad_mat.Resize({h, w});
+
+    Tensor sigma;
+    sigma.mutable_data<T>(weight_mat.dims(), ctx.GetPlace());
+    Tensor uu, vv;
+    TensorCopySync(*u, ctx.GetPlace(), &uu);
+    TensorCopySync(*v, ctx.GetPlace(), &vv);
+    CalcMatrixSigmaAndNormWeight<DeviceContext, T>(
+        &sigma, &(uu.Resize({h, 1})), &(vv.Resize({w, 1})), &weight_mat,
+        power_iters, eps, ctx);
+
+    // uv = U * V^T, shape [h, w]
+    Tensor uv;
+    uv.mutable_data<T>({h, w}, ctx.GetPlace());
+    blas.MatMul(uu.Resize({h, 1}), false, vv.Resize({w, 1}), true, T(1), &uv,
+                T(0));
+
+    Tensor weight_grad_mat;
+    weight_grad_mat.mutable_data<T>({h, w}, ctx.GetPlace());
+    auto weight_grad_mat_t = EigenTensor<T, 2>::From(weight_grad_mat);
+    auto weight_mat_t = EigenTensor<T, 2>::From(weight_mat);
+    auto out_grad_mat_t = EigenTensor<T, 2>::From(out_grad_mat);
+    auto sigma_t = EigenTensor<T, 2>::From(sigma);
+    auto uv_t = EigenTensor<T, 2>::From(uv);
+    // With U and V treated as constants, Out = W / sigma and
+    // sigma = U^T * W * V give
+    //   dW = (dOut - sum(dOut * W_norm) * U * V^T) / sigma,
+    // where W_norm = W / sigma is what weight_mat holds after
+    // CalcMatrixSigmaAndNormWeight. weight_mat is reused here to hold
+    // the broadcast scalar sum(dOut * W_norm).
+    weight_mat_t.device(place) = (weight_mat_t * out_grad_mat_t)
+                                     .sum()
+                                     .eval()
+                                     .reshape(Array2(1, 1))
+                                     .broadcast(Array2(h, w));
+    weight_grad_mat_t.device(place) =
+        (out_grad_mat_t - uv_t * weight_mat_t) / sigma_t;
+
+    if (dim != 0) {
+      std::vector<int> perm;
+      for (int i = 0; i < rank; i++) {
+        if (i < dim) {
+          perm.push_back(i + 1);
+        } else if (i == dim) {
+          perm.push_back(0);
+        } else {
+          perm.push_back(i);
+        }
+      }
+      weight_grad->mutable_data<T>(dims, ctx.GetPlace());
+      TransCompute<DeviceContext, T>(
+          rank, weight_grad_mat.Resize(framework::make_ddim(real_dims)),
+          weight_grad, perm, dev_ctx);
+    } else {
+      TensorCopySync(weight_grad_mat.Resize(dims), ctx.GetPlace(),
+                     weight_grad);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
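Reviewer note: the closed-form gradient used in SpectralNormGradKernel above (u and v held constant, sigma = u^T W v) can be checked against central differences. A minimal NumPy sketch, assuming nothing beyond NumPy; all names are local to the sketch:

    import numpy as np

    np.random.seed(0)
    h, w = 4, 6
    W = np.random.randn(h, w)
    u = np.random.randn(h, 1)   # fixed, as with power_iters = 0
    v = np.random.randn(w, 1)
    G = np.random.randn(h, w)   # upstream gradient dL/dOut

    def loss(Wx):
        # L = sum(G * Out) with Out = W / sigma(W), sigma = u^T W v
        sigma = (u.T @ Wx @ v).item()
        return (G * (Wx / sigma)).sum()

    sigma = (u.T @ W @ v).item()
    # dW = (G - sum(G * W/sigma) * u v^T) / sigma
    analytic = (G - (G * (W / sigma)).sum() * (u @ v.T)) / sigma

    numeric = np.zeros_like(W)
    eps = 1e-6
    for k in range(h):
        for l in range(w):
            Wp, Wm = W.copy(), W.copy()
            Wp[k, l] += eps
            Wm[k, l] -= eps
            numeric[k, l] = (loss(Wp) - loss(Wm)) / (2 * eps)

    # The maximum discrepancy should be negligibly small.
    print(np.abs(analytic - numeric).max())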
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index efb400ccc6d43df44325dc7ef88c14afe4b704c3..0f4fe1b559e1e79bace82e13f0f8828b869d69b7 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -94,6 +94,7 @@ __all__ = [
     'multiplex',
     'layer_norm',
     'group_norm',
+    'spectral_norm',
     'softmax_with_cross_entropy',
     'smooth_l1',
     'one_hot',
@@ -3346,6 +3347,98 @@ def group_norm(input,
     return helper.append_activation(group_norm_out)
 
 
+@templatedoc()
+def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None):
+    """
+    **Spectral Normalization Layer**
+
+    This layer calculates the spectral normalization value of the weight
+    parameters of fc, conv1d, conv2d and conv3d layers, which should be
+    2-D, 3-D, 4-D or 5-D parameters. The calculation proceeds as follows.
+
+    Step 1:
+    Generate vector U in shape of [H] and V in shape of [W],
+    where H is the :attr:`dim`-th dimension of the input weights
+    and W is the product of the remaining dimensions.
+
+    Step 2:
+    :attr:`power_iters` should be a non-negative integer. Perform the
+    following calculations with U and V for :attr:`power_iters` rounds.
+
+    .. math::
+
+        \mathbf{v} := \\frac{\mathbf{W}^{T} \mathbf{u}}{\|\mathbf{W}^{T} \mathbf{u}\|_2}
+
+        \mathbf{u} := \\frac{\mathbf{W} \mathbf{v}}{\|\mathbf{W} \mathbf{v}\|_2}
+
+    Step 3:
+    Calculate :math:`\sigma(\mathbf{W})` and normalize the weight values.
+
+    .. math::
+
+        \sigma(\mathbf{W}) = \mathbf{u}^{T} \mathbf{W} \mathbf{v}
+
+        \mathbf{W} = \\frac{\mathbf{W}}{\sigma(\mathbf{W})}
+
+
+    Refer to `Spectral Normalization <https://arxiv.org/abs/1802.05957>`_ .
+
+    Args:
+        weight(${weight_type}): ${weight_comment}
+        dim(int): ${dim_comment}
+        power_iters(int): ${power_iters_comment}
+        eps(float): ${eps_comment}
+        name(str): The name of this layer. It is optional.
+
+    Returns:
+        Variable: A tensor variable of weight parameters after spectral normalization.
+
+    Examples:
+
+        >>> weight = fluid.layers.data(name='weight', shape=[8, 32, 32],
+        >>>                            dtype='float32')
+        >>> x = fluid.layers.spectral_norm(weight=weight, dim=1, power_iters=2)
+    """
+    helper = LayerHelper('spectral_norm', **locals())
+    dtype = weight.dtype
+
+    # create input and parameters
+    inputs = {'Weight': weight}
+    input_shape = weight.shape
+    h = input_shape[dim]
+    w = np.prod(input_shape) // h
+
+    u = helper.create_parameter(
+        attr=ParamAttr(),
+        shape=[h],
+        dtype=dtype,
+        default_initializer=Normal(0., 1.))
+    u.stop_gradient = True
+    inputs['U'] = u
+    v = helper.create_parameter(
+        attr=ParamAttr(),
+        shape=[w],
+        dtype=dtype,
+        default_initializer=Normal(0., 1.))
+    inputs['V'] = v
+    v.stop_gradient = True
+
+    # create output
+    out = helper.create_variable(dtype=dtype)
+
+    helper.append_op(
+        type="spectral_norm",
+        inputs=inputs,
+        outputs={"Out": out, },
+        attrs={
+            "dim": dim,
+            "power_iters": power_iters,
+            "eps": eps,
+        })
+
+    return out
+
+
 def conv2d_transpose(input,
                      num_filters,
                      output_size=None,
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index 30194f8cacfea2361ffe4afe537287a261cf470b..ff49c1be979a2076952963ec54302fb68361eedf 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -1035,6 +1035,19 @@ class TestBook(unittest.TestCase):
 
         print(str(program))
 
+    def test_spectral_norm(self):
+        program = Program()
+        with program_guard(program):
+            weight = layers.data(
+                name='weight',
+                shape=[2, 3, 32, 32],
+                dtype="float32",
+                append_batch_size=False)
+            out = layers.spectral_norm(weight, dim=1, power_iters=1)
+            self.assertIsNotNone(out)
+
+        print(str(program))
+
     def test_shuffle_channel(self):
         program = Program()
         with program_guard(program):
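Reviewer note: a hypothetical end-to-end usage sketch of the new Python API, assuming a Paddle 1.x environment; create_parameter is just one way to obtain a weight Variable, and the executor boilerplate is illustrative, not part of the patch:

    import paddle.fluid as fluid

    # Apply spectral normalization to a standalone conv-style weight
    # (dim=1 flattens to [3, 2*32*32] before the power iteration).
    weight = fluid.layers.create_parameter(
        shape=[2, 3, 32, 32], dtype='float32')
    w_sn = fluid.layers.spectral_norm(weight=weight, dim=1, power_iters=2)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())  # initializes weight, U and V
    out = exe.run(fetch_list=[w_sn])[0]
    print(out.shape)  # (2, 3, 32, 32), rescaled so its spectral norm is ~1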
diff --git a/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py b/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4e431bcce571798893ccc96c74fd9972b657f3e
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py
@@ -0,0 +1,122 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+
+import unittest
+import numpy as np
+from op_test import OpTest
+
+from paddle.fluid import core
+
+
+def spectral_norm(weight, u, v, dim, power_iters, eps):
+    # NumPy reference implementation of the spectral_norm op
+    shape = weight.shape
+    weight_mat = weight.copy()
+    h = shape[dim]
+    w = np.prod(shape) // h
+    if dim != 0:
+        perm = [dim] + [d for d in range(len(shape)) if d != dim]
+        weight_mat = weight_mat.transpose(perm)
+    weight_mat = weight_mat.reshape((h, w))
+
+    u = u.reshape((h, 1))
+    v = v.reshape((w, 1))
+    for i in range(power_iters):
+        v = np.matmul(weight_mat.T, u)
+        v_norm = np.sqrt((v * v).sum())
+        v = v / (v_norm + eps)
+        u = np.matmul(weight_mat, v)
+        u_norm = np.sqrt((u * u).sum())
+        u = u / (u_norm + eps)
+
+    sigma = (u * np.matmul(weight_mat, v)).sum()
+    return weight / sigma
+
+
+class TestSpectralNormOpNoGrad(OpTest):
+    def setUp(self):
+        self.initTestCase()
+        self.op_type = 'spectral_norm'
+        weight = np.random.random(self.weight_shape).astype('float32')
+        u = np.random.normal(0., 1., self.u_shape).astype('float32')
+        v = np.random.normal(0., 1., self.v_shape).astype('float32')
+
+        self.attrs = {
+            "dim": self.dim,
+            "power_iters": self.power_iters,
+            "eps": self.eps,
+        }
+
+        self.inputs = {
+            "Weight": weight,
+            "U": u,
+            "V": v,
+        }
+
+        output = spectral_norm(weight, u, v, self.dim, self.power_iters,
+                               self.eps)
+        self.outputs = {"Out": output}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def initTestCase(self):
+        self.weight_shape = (2, 3)
+        self.u_shape = (2, )
+        self.v_shape = (3, )
+        self.dim = 0
+        self.power_iters = 5
+        self.eps = 1e-12
+
+
+class TestSpectralNormOpNoGrad2(TestSpectralNormOpNoGrad):
+    def initTestCase(self):
+        self.weight_shape = (2, 3, 3, 3)
+        self.u_shape = (3, )
+        self.v_shape = (18, )
+        self.dim = 1
+        self.power_iters = 10
+        self.eps = 1e-12
+
+
+class TestSpectralNormOp(TestSpectralNormOpNoGrad):
+    def test_check_grad_ignore_uv(self):
+        self.check_grad(
+            ['Weight'],
+            'Out',
+            no_grad_set=set(["U", "V"]),
+            max_relative_error=0.1)
+
+    def initTestCase(self):
+        self.weight_shape = (2, 3)
+        self.u_shape = (2, )
+        self.v_shape = (3, )
+        self.dim = 0
+        # gradient checks use power_iters=0 so that U and V stay constant
+        # and the analytic gradient matches the numeric one
+        self.power_iters = 0
+        self.eps = 1e-12
+
+
+class TestSpectralNormOp2(TestSpectralNormOp):
+    def initTestCase(self):
+        self.weight_shape = (2, 3, 3, 3)
+        self.u_shape = (3, )
+        self.v_shape = (18, )
+        self.dim = 1
+        self.power_iters = 0
+        self.eps = 1e-12
+
+
+if __name__ == "__main__":
+    unittest.main()
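Reviewer note: a quick, illustrative sanity check of the NumPy reference above (assumes the test module is importable from the same directory, not part of the test suite): with enough power iterations, the rescaled weight should have spectral norm close to 1.

    import numpy as np
    from test_spectral_norm_op import spectral_norm  # reference above

    np.random.seed(0)
    weight = np.random.random((3, 4, 5)).astype('float64')
    u = np.random.normal(size=(4, ))     # dim=1, so H = 4
    v = np.random.normal(size=(15, ))    # W = 3 * 5 = 15

    out = spectral_norm(weight, u, v, dim=1, power_iters=50, eps=1e-12)
    # Flatten with dim=1 moved to the front, same as the reference does.
    out_mat = out.transpose(1, 0, 2).reshape(4, 15)
    print(np.linalg.svd(out_mat, compute_uv=False)[0])  # ~1.0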