From 3bf1ae9b599db1615d56323443deb39976bdb16f Mon Sep 17 00:00:00 2001
From: dengkaipeng
Date: Wed, 20 Feb 2019 21:39:01 +0800
Subject: [PATCH] add spectral_norm forward kernel

---
 paddle/fluid/operators/spectral_norm_op.cc    | 154 ++++++++++++++++++++
 paddle/fluid/operators/spectral_norm_op.h     | 131 +++++++++++++++
 .../tests/unittests/test_spectral_norm_op.py  |  80 +++++++++
 3 files changed, 365 insertions(+)
 create mode 100644 paddle/fluid/operators/spectral_norm_op.cc
 create mode 100644 paddle/fluid/operators/spectral_norm_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_spectral_norm_op.py

diff --git a/paddle/fluid/operators/spectral_norm_op.cc b/paddle/fluid/operators/spectral_norm_op.cc
new file mode 100644
index 0000000000..e7fbf4e6ec
--- /dev/null
+++ b/paddle/fluid/operators/spectral_norm_op.cc
@@ -0,0 +1,154 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+   http://www.apache.org/licenses/LICENSE-2.0
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/fluid/operators/spectral_norm_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+class SpectralNormOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Weight"),
+                   "Input(Weight) of SpectralNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("U"),
+                   "Input(U) of SpectralNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("V"),
+                   "Input(V) of SpectralNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of SpectralNormOp should not be null.");
+
+    auto dim_weight = ctx->GetInputDim("Weight");
+    auto rank_weight = dim_weight.size();
+    PADDLE_ENFORCE(rank_weight >= 2 && rank_weight <= 5,
+                   "The rank of Input(Weight) can only be 2, 3, "
+                   "4, 5 for fc, conv1d, conv2d, conv3d layers.");
+
+    int dim = ctx->Attrs().Get<int>("dim");
+    int power_iters = ctx->Attrs().Get<int>("power_iters");
+    PADDLE_ENFORCE(dim >= 0 && dim < rank_weight - 1,
+                   "Attr(dim) should be greater than or equal to 0 and "
+                   "less than the rank of Input(Weight) - 1.");
+    PADDLE_ENFORCE(power_iters >= 0,
+                   "Attr(power_iters) should be greater than or equal to 0.");
+
+    ctx->SetOutputDim("Out", dim_weight);
+    ctx->ShareLoD("Weight", /*->*/ "Out");
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(ctx.Input<Tensor>("Weight")->type(),
+                                   ctx.GetPlace());
+  }
+};
+
+class SpectralNormOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Weight",
+             "The input weight tensor of spectral_norm operator. "
+             "This can be a 2-D, 3-D, 4-D or 5-D tensor which is the "
+             "weight of an fc, conv1d, conv2d or conv3d layer.");
+    AddInput("U",
+             "The weight_u tensor of spectral_norm operator. "
+             "This is a 1-D tensor with shape [H], where H is the first "
+             "dimension of Weight after reshaping it into a 2-D matrix "
+             "according to Attr(dim).");
+    AddInput("V",
+             "The weight_v tensor of spectral_norm operator. "
+             "This is a 1-D tensor with shape [W], where W is the second "
+             "dimension of Weight after reshaping it into a 2-D matrix "
+             "according to Attr(dim).");
+    AddOutput("Out",
+              "The output weight tensor of spectral_norm operator. "
+              "This tensor has the same shape as Input(Weight).");
+
+    AddAttr<int>("dim",
+                 "The dimension corresponding to the number of outputs. "
+                 "Default 0 for fc layers; should be set to 1 for conv1d, "
+                 "conv2d, conv3d layers.")
+        .SetDefault(0);
+    AddAttr<int>("power_iters",
+                 "The number of power iterations used to approximate the "
+                 "spectral norm. Default 1.")
+        .SetDefault(1);
+    AddAttr<float>("eps",
+                   "The epsilon added to norms for numerical stability. "
+                   "Default 1e-12.")
+        .SetDefault(1e-12);
+
+    AddComment(R"DOC(
+      This operator calculates the spectral norm of the weight tensor of
+      an fc, conv1d, conv2d or conv3d layer and rescales the weight by it,
+      which stabilizes discriminator training in GANs (see Miyato et al.,
+      "Spectral Normalization for Generative Adversarial Networks",
+      ICLR 2018).
+
+      The weight tensor is first reshaped into a 2-D matrix W according to
+      Attr(dim). The largest singular value sigma(W) is then approximated
+      with Attr(power_iters) rounds of the power iteration method:
+
+          v := W^T * u / ||W^T * u||_2
+          u := W * v / ||W * v||_2
+          sigma(W) = u^T * W * v
+
+      and the output is computed as Out = Weight / sigma(W).
+      )DOC");
+  }
+};
+
+class SpectralNormOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Weight"), "Input(Weight) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput("U"), "Input(U) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput("V"), "Input(V) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) should not be null");
+    auto dim_x = ctx->GetInputDim("Weight");
+    if (ctx->HasOutput(framework::GradVarName("Weight"))) {
+      ctx->SetOutputDim(framework::GradVarName("Weight"), dim_x);
+    }
+  }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(ctx.Input<Tensor>("Weight")->type(),
+                                   ctx.GetPlace());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(spectral_norm, ops::SpectralNormOp, ops::SpectralNormOpMaker,
+                  paddle::framework::DefaultGradOpDescMaker<true>);
+REGISTER_OPERATOR(spectral_norm_grad, ops::SpectralNormOpGrad);
+REGISTER_OP_CPU_KERNEL(
+    spectral_norm,
+    ops::SpectralNormKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SpectralNormKernel<paddle::platform::CPUDeviceContext, double>);
+REGISTER_OP_CPU_KERNEL(
+    spectral_norm_grad,
+    ops::SpectralNormGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SpectralNormGradKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/spectral_norm_op.h b/paddle/fluid/operators/spectral_norm_op.h
new file mode 100644
index 0000000000..876dacf3bb
--- /dev/null
+++ b/paddle/fluid/operators/spectral_norm_op.h
@@ -0,0 +1,131 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+   http://www.apache.org/licenses/LICENSE-2.0
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T, size_t D, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
+using Tensor = framework::Tensor;
+
+using Array1 = Eigen::DSizes<int64_t, 1>;
+using Array2 = Eigen::DSizes<int64_t, 2>;
+using IndexPair = Eigen::IndexPair<int>;
+
+// Reshape the weight tensor into a 2-D matrix [h, w]: dimensions
+// [0, dim] are collapsed into h, the remaining dimensions into w.
+static inline void ResizeWeight(Tensor* weight_mat, const int dim) {
+  auto weight_dims = weight_mat->dims();
+  int h = 1;
+  int w = 1;
+  for (int i = 0; i < weight_dims.size(); i++) {
+    if (i <= dim) {
+      h *= weight_dims[i];
+    } else {
+      w *= weight_dims[i];
+    }
+  }
+  *weight_mat = weight_mat->Resize({h, w});
+}
+
+// Approximate the largest singular value (the spectral norm) of the
+// [h, w] matrix `weight` with `power_iters` rounds of power iteration:
+//   v = W^T * u / ||W^T * u||_2
+//   u = W * v / ||W * v||_2
+//   sigma = u^T * W * v
+// and then normalize `weight` by sigma in place.
+template <typename DeviceContext, typename T>
+static inline void CalcMatrixSigmaAndNormWeight(
+    Tensor* sigma, Tensor* u, Tensor* v, Tensor* weight, const int power_iters,
+    const float eps, const framework::ExecutionContext& ctx) {
+  auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
+  auto sigma_t = EigenTensor<T, 2>::From(*sigma);
+  auto weight_t = EigenTensor<T, 2>::From(*weight);
+  auto u_t = EigenTensor<T, 1>::From(*u);
+  auto v_t = EigenTensor<T, 1>::From(*v);
+
+  const int h = weight->dims()[0];
+  const int w = weight->dims()[1];
+
+  Eigen::array<int, 2> perm = {1, 0};
+  Eigen::array<IndexPair, 1> product_dims = {IndexPair(1, 0)};
+  auto weight_trans_t = weight_t.shuffle(perm);
+  for (int i = 0; i < power_iters; i++) {
+    // v = W^T * u, followed by l2 normalization.
+    v_t.device(place) = weight_trans_t.contract(u_t, product_dims);
+    auto v_t_norm =
+        v_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast(
+            Array1(w));
+    v_t.device(place) = v_t / (v_t_norm + v_t_norm.constant(eps));
+    // u = W * v, followed by l2 normalization.
+    u_t.device(place) = weight_t.contract(v_t, product_dims);
+    auto u_t_norm =
+        u_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast(
+            Array1(h));
+    u_t.device(place) = u_t / (u_t_norm + u_t_norm.constant(eps));
+  }
+  // sigma = u^T * W * v, broadcast to [h, w] so the weight can be
+  // normalized element-wise.
+  sigma_t.device(place) = (u_t * weight_t.contract(v_t, product_dims))
+                              .sum()
+                              .eval()
+                              .reshape(Array2(1, 1))
+                              .broadcast(Array2(h, w));
+  weight_t.device(place) = weight_t / sigma_t;
+}
+
+template <typename DeviceContext, typename T>
+class SpectralNormKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto weight = ctx.Input<Tensor>("Weight");
+    auto u = ctx.Input<Tensor>("U");
+    auto v = ctx.Input<Tensor>("V");
+    auto out = ctx.Output<Tensor>("Out");
+
+    int dim = ctx.Attr<int>("dim");
+    int power_iters = ctx.Attr<int>("power_iters");
+    float eps = ctx.Attr<float>("eps");
+
+    Tensor weight_mat;
+    TensorCopySync(*weight, ctx.GetPlace(), &weight_mat);
+    ResizeWeight(&weight_mat, dim);
+
+    Tensor sigma;
+    sigma.mutable_data<T>(weight_mat.dims(), ctx.GetPlace());
+    // Run power iteration on copies of U and V so the inputs stay intact.
+    Tensor uu, vv;
+    TensorCopySync(*u, ctx.GetPlace(), &uu);
+    TensorCopySync(*v, ctx.GetPlace(), &vv);
+    CalcMatrixSigmaAndNormWeight<DeviceContext, T>(
+        &sigma, &uu, &vv, &weight_mat, power_iters, eps, ctx);
+    // Restore the original weight shape before writing the output.
+    TensorCopySync(weight_mat.Resize(weight->dims()), ctx.GetPlace(), out);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class SpectralNormGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    // Not implemented yet: this patch only adds the forward kernel.
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py b/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py
new file mode 100644
index 0000000000..2d7ff16aa6
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+
+import unittest
+import numpy as np
+from op_test import OpTest
+
+from paddle.fluid import core
+
+
+def spectral_norm(weight, u, v, dim, power_iters, eps):
+    """Numpy reference implementation of the spectral_norm forward pass."""
+    shape = weight.shape
+    h = int(np.prod(shape[:dim + 1]))
+    w = int(np.prod(shape[dim + 1:]))
+    weight_mat = weight.reshape((h, w))
+
+    u = u.reshape((h, 1))
+    v = v.reshape((w, 1))
+    for i in range(power_iters):
+        v = np.matmul(weight_mat.T, u)
+        v = v / (np.linalg.norm(v) + eps)
+        u = np.matmul(weight_mat, v)
+        u = u / (np.linalg.norm(u) + eps)
+
+    sigma = (u * np.matmul(weight_mat, v)).sum()
+    return weight / sigma
+
+
+class TestSpectralNormOp(OpTest):
+    def setUp(self):
+        self.initTestCase()
+        self.op_type = 'spectral_norm'
+        weight = np.random.random(self.weight_shape).astype('float32')
+        u = np.random.normal(0., 1., self.u_shape).astype('float32')
+        v = np.random.normal(0., 1., self.v_shape).astype('float32')
+
+        self.attrs = {
+            "dim": self.dim,
+            "power_iters": self.power_iters,
+            "eps": self.eps,
+        }
+
+        self.inputs = {
+            "Weight": weight,
+            "U": u,
+            "V": v,
+        }
+
+        output = spectral_norm(weight, u, v, self.dim, self.power_iters,
+                               self.eps)
+        self.outputs = {"Out": output}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def initTestCase(self):
+        self.weight_shape = (2, 3)
+        self.u_shape = (2, )
+        self.v_shape = (3, )
+        self.dim = 0
+        self.power_iters = 1
+        self.eps = 1e-12
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab
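For a quick sanity check of the power-iteration math above, here is a minimal standalone NumPy sketch (only NumPy is assumed; spectral_norm_ref is a hypothetical local helper mirroring the reference function in the test, not part of the patch). It runs on the small 2x3 weight used in an early draft of the test, whose rows are [1, 1, 1] and [2, 2, 2]; since that matrix has rank one, a single power iteration already recovers its exact largest singular value, sqrt(15) ~= 3.873:

import numpy as np

def spectral_norm_ref(weight, u, v, dim, power_iters, eps):
    # Collapse dims [0, dim] into h and the rest into w, like ResizeWeight.
    h = int(np.prod(weight.shape[:dim + 1]))
    w = int(np.prod(weight.shape[dim + 1:]))
    w_mat = weight.reshape((h, w))
    u = u.reshape((h, 1))
    v = v.reshape((w, 1))
    for _ in range(power_iters):
        v = np.matmul(w_mat.T, u)          # v <- W^T u
        v = v / (np.linalg.norm(v) + eps)  # l2-normalize v
        u = np.matmul(w_mat, v)            # u <- W v
        u = u / (np.linalg.norm(u) + eps)  # l2-normalize u
    sigma = (u * np.matmul(w_mat, v)).sum()  # sigma = u^T W v
    return weight / sigma, sigma

weight = np.array([[1., 1., 1.], [2., 2., 2.]])
out, sigma = spectral_norm_ref(weight, np.ones(2), np.ones(3),
                               dim=0, power_iters=1, eps=1e-12)
print(sigma)  # ~3.8730, i.e. sqrt(15); exact because weight has rank one
print(out)    # weight / sqrt(15)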