diff --git a/paddle/fluid/operators/spectral_norm_op.cc b/paddle/fluid/operators/spectral_norm_op.cc index a6addb2e6f46d87f7becb3b00f2087b83a558692..1d47a10d56bc22e375df29bd278181b7701e1dad 100644 --- a/paddle/fluid/operators/spectral_norm_op.cc +++ b/paddle/fluid/operators/spectral_norm_op.cc @@ -9,12 +9,14 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/spectral_norm_op.h" - #include +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/ternary.h" + namespace paddle { namespace operators { @@ -24,82 +26,6 @@ class SpectralNormOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - protected: - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "SpectralNorm"); - OP_INOUT_CHECK(ctx->HasInput("U"), "Input", "U", "SpectralNorm"); - OP_INOUT_CHECK(ctx->HasInput("V"), "Input", "V", "SpectralNorm"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SpectralNorm"); - - auto dim_weight = ctx->GetInputDim("Weight"); - auto rank_weight = dim_weight.size(); - PADDLE_ENFORCE_GE(rank_weight, - 2, - platform::errors::InvalidArgument( - "The rank of Input(Weights) should be greater equal " - "than 2, but received Weight rank(%d)", - rank_weight)); - PADDLE_ENFORCE_LE(rank_weight, - 5, - platform::errors::InvalidArgument( - "The rank of Input(Weights) should be less equal " - "than 5, but received Weight rank(%d)", - rank_weight)); - - int dim = ctx->Attrs().Get("dim"); - int power_iters = ctx->Attrs().Get("power_iters"); - auto dim_valid = dim == 0 || dim == 1; - PADDLE_ENFORCE_EQ( - dim_valid, - true, - platform::errors::InvalidArgument( - "Attr(dim) can only be 0 or 1, but received %d", dim)); - PADDLE_ENFORCE_GE( - power_iters, - 0, - platform::errors::InvalidArgument( - "Attr(power_iters) should be greater equal then 0, but received %d", - power_iters)); - - int h = dim_weight[dim]; - int w = 1; - for (int i = 0; i < rank_weight; i++) { - if (i != dim) { - w *= dim_weight[i]; - } - } - auto dim_u = ctx->GetInputDim("U"); - auto dim_v = ctx->GetInputDim("V"); - - if (ctx->IsRuntime() || (dim_u[0] > 0 && h > 0)) { - PADDLE_ENFORCE_EQ(dim_u[0], - h, - platform::errors::InvalidArgument( - "Input(U) dimension[0] should be equal to " - "Input(Weight) dimension[Attr(dim)], but received " - "U dimension[0](%d) != Weight dimension[%d](%d)", - dim_u[0], - dim, - h)); - } - - if (ctx->IsRuntime() || (dim_v[0] > 0 && w > 0)) { - PADDLE_ENFORCE_EQ( - dim_v[0], - w, - platform::errors::InvalidArgument( - "Input(V) dimension[0] should be equal to the product of " - "Input(Weight) dimension except dimension[Attr(dim)], but " - "received V dimension[0](%d) != product of Input(Weight) " - "dimension(%d)", - dim_v[0], - w)); - } - - ctx->SetOutputDim("Out", dim_weight); - ctx->ShareLoD("Weight", /*->*/ "Out"); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -219,26 +145,6 @@ class SpectralNormOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK( - ctx->HasInput("Weight"), "Input", "Weight", "SpectralNormGrad"); - OP_INOUT_CHECK(ctx->HasInput("U"), "Input", "U", "SpectralNormGrad"); - OP_INOUT_CHECK(ctx->HasInput("V"), "Input", "V", "SpectralNormGrad"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), - "Input", - "Out@GRAD", - "SpectralNormGrad"); - - PADDLE_ENFORCE_EQ( - ctx->HasInput(framework::GradVarName("Out")), - true, - platform::errors::NotFound("Input(Out@GRAD) should not be null")); - auto dim_x = ctx->GetInputDim("Weight"); - if (ctx->HasOutput(framework::GradVarName("Weight"))) { - ctx->SetOutputDim(framework::GradVarName("Weight"), dim_x); - } - } - framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( @@ -250,15 +156,20 @@ class SpectralNormOpGrad : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; + +DECLARE_INFER_SHAPE_FUNCTOR(spectral_norm, + SpectralNormInferMetaFunctor, + PD_INFER_META(phi::SpectralNormInferMeta)); +DECLARE_INFER_SHAPE_FUNCTOR(spectral_norm_grad, + SpectralNormGradInferMetaFunctor, + PD_INFER_META(phi::SpectralNormGradInferMeta)); + REGISTER_OPERATOR(spectral_norm, ops::SpectralNormOp, ops::SpectralNormOpMaker, ops::SpectralNormGradOpMaker, - ops::SpectralNormGradOpMaker); -REGISTER_OPERATOR(spectral_norm_grad, ops::SpectralNormOpGrad); -REGISTER_OP_CPU_KERNEL(spectral_norm, - ops::SpectralNormKernel, - ops::SpectralNormKernel); -REGISTER_OP_CPU_KERNEL(spectral_norm_grad, - ops::SpectralNormGradKernel, - ops::SpectralNormGradKernel); + ops::SpectralNormGradOpMaker, + SpectralNormInferMetaFunctor); +REGISTER_OPERATOR(spectral_norm_grad, + ops::SpectralNormOpGrad, + SpectralNormGradInferMetaFunctor); diff --git a/paddle/fluid/operators/spectral_norm_op.cu b/paddle/fluid/operators/spectral_norm_op.cu deleted file mode 100644 index ea90e3b4c122b00d5bfe13617e48a9bbe0ee8395..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/spectral_norm_op.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#include "paddle/fluid/operators/spectral_norm_op.h" - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - spectral_norm, - ops::SpectralNormKernel, - ops::SpectralNormKernel); -REGISTER_OP_CUDA_KERNEL( - spectral_norm_grad, - ops::SpectralNormGradKernel, - ops::SpectralNormGradKernel); diff --git a/paddle/fluid/operators/spectral_norm_op.h b/paddle/fluid/operators/spectral_norm_op.h deleted file mode 100644 index ffe8a40c35a468075bebf9d5022cf15ac7c0175c..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/spectral_norm_op.h +++ /dev/null @@ -1,299 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#pragma once -#include - -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -template -using EigenTensor = framework::EigenTensor; -using Tensor = framework::Tensor; - -using Array1 = Eigen::DSizes; -using Array2 = Eigen::DSizes; -using IndexPair = Eigen::IndexPair; - -template -static inline void TransCompute(const int rank, - const Tensor& in, - Tensor* out, - const std::vector& perm, - const DeviceContext& dev_ctx) { - if (rank <= 1 || rank > 5) { - PADDLE_THROW(paddle::platform::errors::Fatal( - "Weight rank of SpectralNorm should be in range [2, 5], but got %d.", - rank)); - } - - switch (rank) { - case 2: - phi::funcs::Transpose trans2; - trans2(dev_ctx, in, out, perm); - break; - case 3: - phi::funcs::Transpose trans3; - trans3(dev_ctx, in, out, perm); - break; - case 4: - phi::funcs::Transpose trans4; - trans4(dev_ctx, in, out, perm); - break; - case 5: - phi::funcs::Transpose trans5; - trans5(dev_ctx, in, out, perm); - break; - default: - break; - } -} - -template -static inline void CalcMatrixSigmaAndNormWeight( - Tensor* sigma, - Tensor* u, - Tensor* v, - Tensor* weight, - const int power_iters, - const float eps, - const framework::ExecutionContext& ctx) { - auto& place = *ctx.template device_context().eigen_device(); - auto blas = phi::funcs::GetBlas(ctx); - auto sigma_t = EigenTensor::From(*sigma); - auto weight_t = EigenTensor::From(*weight); - auto u_t = EigenTensor::From(*u); - auto v_t = EigenTensor::From(*v); - - const int h = weight->dims()[0]; - const int w = weight->dims()[1]; - - for (int i = 0; i < power_iters; i++) { - // V = W^T * U / ||W^T * U||_2 - blas.MatMul(*weight, true, *u, false, T(1), v, T(0)); - auto v_t_norm = - v_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast( - Array1(w)); - v_t.device(place) = v_t / (v_t_norm + v_t_norm.constant(eps)); - // U = W^T * V / ||W^T * V||_2 - blas.MatMul(*weight, false, *v, false, T(1), u, T(0)); - auto u_t_norm = - u_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast( - Array1(h)); - u_t.device(place) = u_t / (u_t_norm + u_t_norm.constant(eps)); - } - Tensor weight_v; - weight_v.mutable_data({h, 1}, ctx.GetPlace()); - blas.MatMul(*weight, false, *v, false, T(1), &weight_v, T(0)); - auto weight_v_t = EigenTensor::From(weight_v); - sigma_t.device(place) = (u_t * weight_v_t) - .sum() - .eval() - .reshape(Array2(1, 1)) - .broadcast(Array2(h, w)); - weight_t.device(place) = weight_t / sigma_t; -} - -template -class SpectralNormKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto& dev_ctx = ctx.template device_context(); - auto weight = ctx.Input("Weight"); - auto u = ctx.Input("U"); - auto v = ctx.Input("V"); - auto out = ctx.Output("Out"); - - int dim = ctx.Attr("dim"); - int power_iters = ctx.Attr("power_iters"); - float eps = ctx.Attr("eps"); - - const int h = u->dims()[0]; - const int w = v->dims()[0]; - - Tensor weight_mat; - auto dims = weight->dims(); - const int rank = dims.size(); - std::vector real_dims; - if (dim != 0) { - std::vector perm; - perm.push_back(dim); - real_dims.push_back(dims[dim]); - for (int i = 0; i < rank; i++) { - if (i != dim) { - perm.push_back(i); - real_dims.push_back(dims[i]); - } - } - weight_mat.mutable_data(phi::make_ddim(real_dims), ctx.GetPlace()); - TransCompute(rank, *weight, &weight_mat, perm, dev_ctx); - } else { - for (int i = 0; i < rank; i++) { - real_dims.push_back(i); - } - paddle::framework::TensorCopySync(*weight, ctx.GetPlace(), &weight_mat); - } - weight_mat = weight_mat.Resize({h, w}); - - Tensor sigma; - sigma.mutable_data(weight_mat.dims(), ctx.GetPlace()); - Tensor uu, vv; - paddle::framework::TensorCopySync(*u, ctx.GetPlace(), &uu); - paddle::framework::TensorCopySync(*v, ctx.GetPlace(), &vv); - CalcMatrixSigmaAndNormWeight(&sigma, - &(uu.Resize({h, 1})), - &(vv.Resize({w, 1})), - &weight_mat, - power_iters, - eps, - ctx); - - if (dim != 0) { - std::vector perm; - for (int i = 0; i < rank; i++) { - if (i < dim) { - perm.push_back(i + 1); - } else if (i == dim) { - perm.push_back(0); - } else { - perm.push_back(i); - } - } - out->mutable_data(dims, ctx.GetPlace()); - TransCompute( - rank, - weight_mat.Resize(phi::make_ddim(real_dims)), - out, - perm, - dev_ctx); - } else { - paddle::framework::TensorCopySync( - weight_mat.Resize(dims), ctx.GetPlace(), out); - } - } -}; - -template -class SpectralNormGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto& place = *ctx.template device_context().eigen_device(); - auto& dev_ctx = ctx.template device_context(); - auto blas = phi::funcs::GetBlas(ctx); - auto weight = ctx.Input("Weight"); - auto u = ctx.Input("U"); - auto v = ctx.Input("V"); - auto out_grad = ctx.Input(framework::GradVarName("Out")); - auto weight_grad = ctx.Output(framework::GradVarName("Weight")); - - int dim = ctx.Attr("dim"); - int power_iters = ctx.Attr("power_iters"); - float eps = ctx.Attr("eps"); - - const int h = u->dims()[0]; - const int w = v->dims()[0]; - - Tensor weight_mat, out_grad_mat; - auto dims = weight->dims(); - const int rank = dims.size(); - std::vector real_dims; - if (dim != 0) { - std::vector perm; - perm.push_back(dim); - real_dims.push_back(dims[dim]); - for (int i = 0; i < rank; i++) { - if (i != dim) { - perm.push_back(i); - real_dims.push_back(dims[i]); - } - } - weight_mat.mutable_data(phi::make_ddim(real_dims), ctx.GetPlace()); - out_grad_mat.mutable_data(phi::make_ddim(real_dims), ctx.GetPlace()); - TransCompute(rank, *weight, &weight_mat, perm, dev_ctx); - TransCompute( - rank, *out_grad, &out_grad_mat, perm, dev_ctx); - } else { - for (int i = 0; i < rank; i++) { - real_dims.push_back(i); - } - paddle::framework::TensorCopySync(*weight, ctx.GetPlace(), &weight_mat); - paddle::framework::TensorCopySync( - *out_grad, ctx.GetPlace(), &out_grad_mat); - } - weight_mat = weight_mat.Resize({h, w}); - out_grad_mat = out_grad_mat.Resize({h, w}); - - Tensor sigma; - sigma.mutable_data(weight_mat.dims(), ctx.GetPlace()); - Tensor uu, vv; - paddle::framework::TensorCopySync(*u, ctx.GetPlace(), &uu); - paddle::framework::TensorCopySync(*v, ctx.GetPlace(), &vv); - CalcMatrixSigmaAndNormWeight(&sigma, - &(uu.Resize({h, 1})), - &(vv.Resize({w, 1})), - &weight_mat, - power_iters, - eps, - ctx); - - Tensor uv; - uv.mutable_data({h, w}, ctx.GetPlace()); - blas.MatMul( - uu.Resize({h, 1}), false, vv.Resize({w, 1}), false, T(1), &uv, T(0)); - - Tensor weight_grad_mat; - weight_grad_mat.mutable_data({h, w}, ctx.GetPlace()); - auto weight_grad_mat_t = EigenTensor::From(weight_grad_mat); - auto weight_mat_t = EigenTensor::From(weight_mat); - auto out_grad_mat_t = EigenTensor::From(out_grad_mat); - auto sigma_t = EigenTensor::From(sigma); - auto uv_t = EigenTensor::From(uv); - weight_mat_t.device(place) = - weight_mat_t.sum().eval().reshape(Array2(1, 1)).broadcast(Array2(h, w)); - weight_grad_mat_t.device(place) = - out_grad_mat_t * (out_grad_mat_t.constant(1.0) - uv_t * weight_mat_t) / - sigma_t; - - if (dim != 0) { - std::vector perm; - for (int i = 0; i < rank; i++) { - if (i < dim) { - perm.push_back(i + 1); - } else if (i == dim) { - perm.push_back(0); - } else { - perm.push_back(i); - } - } - weight_grad->mutable_data(dims, ctx.GetPlace()); - TransCompute( - rank, - weight_grad_mat.Resize(phi::make_ddim(real_dims)), - weight_grad, - perm, - dev_ctx); - } else { - paddle::framework::TensorCopySync( - weight_grad_mat.Resize(dims), ctx.GetPlace(), weight_grad); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index a36a48d505fc510e21c784ebeb4cf09ac493267c..45b347af9321a720c2fbea3096d405b0afe0b8a6 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -2152,6 +2152,16 @@ use_gpudnn : true backward : softmax_grad +- api : spectral_norm + args : (Tensor weight, Tensor u, Tensor v, int dim, int power_iters, float eps) + output : Tensor + infer_meta : + func : SpectralNormInferMeta + kernel : + func : spectralnorm + data_type : weight + backward : spectral_norm_grad + - api : split args : (Tensor x, IntArray num_or_sections, Scalar(int) axis) output : Tensor[] diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 77fc49f29b2f3fec52d7d3e19d8bcd88cef375d4..95bcebb92bbf6241070e8242b42b52086317b296 100644 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -2037,6 +2037,16 @@ func : softmax_grad use_gpudnn : true +- backward_api : spectral_norm_grad + forward : spectral_norm (Tensor weight, Tensor u, Tensor v, int dim, int power_iters, float eps) -> Tensor(out) + args : (Tensor weight, Tensor u, Tensor v, Tensor out_grad, int dim, int power_iters, float eps) + output : Tensor(weight_grad) + infer_meta : + func : SpectralNormGradInferMeta + kernel : + func : spectral_norm_grad + data_type : out_grad + - backward_api : split_grad forward : split (Tensor x, IntArray num_or_sections, Scalar axis) -> Tensor[](out) args : (Tensor[] out_grad, Scalar axis = -1) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index bfae939820ead6c2c88d5b21ad3b25f34670aaa3..2b377d6727c3699d8ba6fdf48e10763557b3c598 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -685,6 +685,21 @@ void ScatterNdAddGradInferMeta(const MetaTensor& index, } } +void SpectralNormGradInferMeta(const MetaTensor& weight, + const MetaTensor& u, + const MetaTensor& v, + const MetaTensor& out_grad, + int dim, + int power_iters, + float eps, + MetaTensor* weight_grad) { + auto dim_x = weight.dims(); + if (weight_grad) { + weight_grad->set_dims(dim_x); + weight_grad->set_dtype(out_grad.dtype()); + } +} + void StackGradInferMeta(const MetaTensor& out_grad, int axis, std::vector x_grad) { diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 16d9b82e0644233d1c1deb027acb94766e7d08bd..91b9f007bc9e3f7bc2107ef9a3a9a206a2dcc71f 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -288,6 +288,15 @@ void ScatterNdAddGradInferMeta(const MetaTensor& index, MetaTensor* x_grad, MetaTensor* updates_grad); +void SpectralNormGradInferMeta(const MetaTensor& weight, + const MetaTensor& u, + const MetaTensor& v, + const MetaTensor& out_grad, + int dim, + int power_iters, + float eps, + MetaTensor* weight_grad); + void StackGradInferMeta(const MetaTensor& out_grad, int axis, std::vector x_grad); diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index b83febb24b5f1bce907f657be1ce20085150caac..7dc799d989577d4d54ec779ef92d2cdf9fed96d0 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -1088,6 +1088,83 @@ void ScatterNdAddInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); } +void SpectralNormInferMeta(const MetaTensor& weight, + const MetaTensor& u, + const MetaTensor& v, + int dim, + int power_iters, + float eps, + MetaTensor* out, + MetaConfig config) { + auto dim_weight = weight.dims(); + auto rank_weight = dim_weight.size(); + PADDLE_ENFORCE_GE(rank_weight, + 2, + errors::InvalidArgument( + "The rank of Input(Weights) should be greater equal " + "than 2, but received Weight rank(%d)", + rank_weight)); + PADDLE_ENFORCE_LE( + rank_weight, + 5, + errors::InvalidArgument("The rank of Input(Weights) should be less equal " + "than 5, but received Weight rank(%d)", + rank_weight)); + + auto dim_valid = dim == 0 || dim == 1; + PADDLE_ENFORCE_EQ(dim_valid, + true, + errors::InvalidArgument( + "Attr(dim) can only be 0 or 1, but received %d", dim)); + PADDLE_ENFORCE_GE( + power_iters, + 0, + errors::InvalidArgument( + "Attr(power_iters) should be greater equal then 0, but received %d", + power_iters)); + + int h = dim_weight[dim]; + int w = 1; + for (int i = 0; i < rank_weight; i++) { + if (i != dim) { + w *= dim_weight[i]; + } + } + auto dim_u = u.dims(); + auto dim_v = v.dims(); + + if (config.is_runtime || (dim_u[0] > 0 && h > 0)) { + PADDLE_ENFORCE_EQ(dim_u[0], + h, + errors::InvalidArgument( + "Input(U) dimension[0] should be equal to " + "Input(Weight) dimension[Attr(dim)], but received " + "U dimension[0](%d) != Weight dimension[%d](%d)", + dim_u[0], + dim, + h)); + } + + if (config.is_runtime || (dim_v[0] > 0 && w > 0)) { + PADDLE_ENFORCE_EQ( + dim_v[0], + w, + errors::InvalidArgument( + "Input(V) dimension[0] should be equal to the product of " + "Input(Weight) dimension except dimension[Attr(dim)], but " + "received V dimension[0](%d) != product of Input(Weight) " + "dimension(%d)", + dim_v[0], + w)); + } + + if (out) { + out->set_dims(dim_weight); + out->set_dtype(weight.dtype()); + out->share_lod(weight); + } +} + void ViterbiDecodeInferMeta(const MetaTensor& input, const MetaTensor& transition, const MetaTensor& length, diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index 329c6e13a5022e77858c33c9dcdfaa9fb6470831..6cf9b169d6236c3819e2bb5cd5884b00c5bc0838 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -170,6 +170,15 @@ void ScatterNdAddInferMeta(const MetaTensor& x, const MetaTensor& updates, MetaTensor* out); +void SpectralNormInferMeta(const MetaTensor& weight, + const MetaTensor& u, + const MetaTensor& v, + int dim, + int power_iters, + float eps, + MetaTensor* out, + MetaConfig config = MetaConfig()); + void ViterbiDecodeInferMeta(const MetaTensor& input, const MetaTensor& transition, const MetaTensor& length, diff --git a/paddle/phi/kernels/cpu/spectral_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/spectral_norm_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..5603eb0f7cb9127ee84725c8916cc357fbec4f01 --- /dev/null +++ b/paddle/phi/kernels/cpu/spectral_norm_grad_kernel.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h" +#include "paddle/phi/kernels/spectral_norm_grad_kernel.h" + +PD_REGISTER_KERNEL(spectral_norm_grad, + CPU, + ALL_LAYOUT, + phi::SpectralNormGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cpu/spectral_norm_kernel.cc b/paddle/phi/kernels/cpu/spectral_norm_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..6ff25365d1c16d7af21d744285c2b5e7a40d1ec3 --- /dev/null +++ b/paddle/phi/kernels/cpu/spectral_norm_kernel.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/impl/spectral_norm_kernel_impl.h" +#include "paddle/phi/kernels/spectral_norm_kernel.h" + +PD_REGISTER_KERNEL( + spectral_norm, CPU, ALL_LAYOUT, phi::SpectralNormKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/spectral_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/spectral_norm_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..75c82e90fc0591300d7ac852b6f2267c763b690c --- /dev/null +++ b/paddle/phi/kernels/gpu/spectral_norm_grad_kernel.cu @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h" +#include "paddle/phi/kernels/spectral_norm_grad_kernel.h" + +PD_REGISTER_KERNEL(spectral_norm_grad, + GPU, + ALL_LAYOUT, + phi::SpectralNormGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/spectral_norm_kernel.cu b/paddle/phi/kernels/gpu/spectral_norm_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..7709cf5da1b5a7375b6c5bb969b5442c0c85d8d7 --- /dev/null +++ b/paddle/phi/kernels/gpu/spectral_norm_kernel.cu @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/impl/spectral_norm_kernel_impl.h" +#include "paddle/phi/kernels/spectral_norm_kernel.h" + +PD_REGISTER_KERNEL( + spectral_norm, GPU, ALL_LAYOUT, phi::SpectralNormKernel, float, double) {} diff --git a/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h b/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..5bdb874bc89c474d161017445ecdac7411eb59d4 --- /dev/null +++ b/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h @@ -0,0 +1,130 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/impl/spectral_norm_kernel_impl.h" + +namespace phi { + +template +void SpectralNormGradKernel(const Context& dev_ctx, + const DenseTensor& weight, + const DenseTensor& u, + const DenseTensor& v, + const DenseTensor& out_grad, + int dim, + int power_iters, + float eps, + DenseTensor* weight_grad) { + auto& place = *dev_ctx.eigen_device(); + auto blas = phi::funcs::GetBlas(dev_ctx); + + const int h = u.dims()[0]; + const int w = v.dims()[0]; + + DenseTensor weight_mat, out_grad_mat; + auto dims = weight.dims(); + const int rank = dims.size(); + std::vector real_dims; + if (dim != 0) { + std::vector perm; + perm.push_back(dim); + real_dims.push_back(dims[dim]); + for (int i = 0; i < rank; i++) { + if (i != dim) { + perm.push_back(i); + real_dims.push_back(dims[i]); + } + } + weight_mat.Resize(phi::make_ddim(real_dims)); + dev_ctx.template Alloc(&weight_mat); + out_grad_mat.Resize(phi::make_ddim(real_dims)); + dev_ctx.template Alloc(&out_grad_mat); + TransCompute2DTo5D(dev_ctx, weight, rank, perm, &weight_mat); + TransCompute2DTo5D( + dev_ctx, out_grad, rank, perm, &out_grad_mat); + } else { + for (int i = 0; i < rank; i++) { + real_dims.push_back(i); + } + phi::Copy(dev_ctx, weight, dev_ctx.GetPlace(), true, &weight_mat); + phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), true, &out_grad_mat); + } + weight_mat = weight_mat.Resize({h, w}); + out_grad_mat = out_grad_mat.Resize({h, w}); + + DenseTensor sigma; + sigma.Resize(weight_mat.dims()); + dev_ctx.template Alloc(&sigma); + DenseTensor uu, vv; + phi::Copy(dev_ctx, u, dev_ctx.GetPlace(), true, &uu); + phi::Copy(dev_ctx, v, dev_ctx.GetPlace(), true, &vv); + CalcMatrixSigmaAndNormWeight(dev_ctx, + &weight_mat, + &(uu.Resize({h, 1})), + &(vv.Resize({w, 1})), + &sigma, + power_iters, + eps); + + DenseTensor uv; + uv.Resize({h, w}); + dev_ctx.template Alloc(&uv); + blas.MatMul( + uu.Resize({h, 1}), false, vv.Resize({w, 1}), false, T(1), &uv, T(0)); + + DenseTensor weight_grad_mat; + weight_grad_mat.Resize({h, w}); + dev_ctx.template Alloc(&weight_grad_mat); + auto weight_grad_mat_t = EigenTensor::From(weight_grad_mat); + auto weight_mat_t = EigenTensor::From(weight_mat); + auto out_grad_mat_t = EigenTensor::From(out_grad_mat); + auto sigma_t = EigenTensor::From(sigma); + auto uv_t = EigenTensor::From(uv); + weight_mat_t.device(place) = + weight_mat_t.sum().eval().reshape(Array2(1, 1)).broadcast(Array2(h, w)); + weight_grad_mat_t.device(place) = + out_grad_mat_t * (out_grad_mat_t.constant(1.0) - uv_t * weight_mat_t) / + sigma_t; + + if (dim != 0) { + std::vector perm; + for (int i = 0; i < rank; i++) { + if (i < dim) { + perm.push_back(i + 1); + } else if (i == dim) { + perm.push_back(0); + } else { + perm.push_back(i); + } + } + weight_grad->Resize(dims); + dev_ctx.template Alloc(weight_grad); + TransCompute2DTo5D( + dev_ctx, + weight_grad_mat.Resize(phi::make_ddim(real_dims)), + rank, + perm, + weight_grad); + } else { + phi::Copy(dev_ctx, + weight_grad_mat.Resize(dims), + dev_ctx.GetPlace(), + true, + weight_grad); + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h b/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..57c5c69a63d614486060a7927730a34bf122d68e --- /dev/null +++ b/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h @@ -0,0 +1,177 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +using Array1 = Eigen::DSizes; +using Array2 = Eigen::DSizes; +using IndexPair = Eigen::IndexPair; + +template +static inline void TransCompute2DTo5D(const Context& dev_ctx, + const DenseTensor& in, + const int rank, + const std::vector& perm, + DenseTensor* out) { + if (rank <= 1 || rank > 5) { + PADDLE_THROW(phi::errors::Fatal( + "Weight rank of SpectralNorm should be in range [2, 5], but got %d.", + rank)); + } + + switch (rank) { + case 2: + phi::funcs::Transpose trans2; + trans2(dev_ctx, in, out, perm); + break; + case 3: + phi::funcs::Transpose trans3; + trans3(dev_ctx, in, out, perm); + break; + case 4: + phi::funcs::Transpose trans4; + trans4(dev_ctx, in, out, perm); + break; + case 5: + phi::funcs::Transpose trans5; + trans5(dev_ctx, in, out, perm); + break; + default: + break; + } +} + +template +static inline void CalcMatrixSigmaAndNormWeight(const Context& dev_ctx, + DenseTensor* weight, + DenseTensor* u, + DenseTensor* v, + DenseTensor* sigma, + const int power_iters, + const float eps) { + auto& place = *dev_ctx.eigen_device(); + auto blas = funcs::GetBlas(dev_ctx); + auto sigma_t = EigenTensor::From(*sigma); + auto weight_t = EigenTensor::From(*weight); + auto u_t = EigenTensor::From(*u); + auto v_t = EigenTensor::From(*v); + + const int h = weight->dims()[0]; + const int w = weight->dims()[1]; + + for (int i = 0; i < power_iters; i++) { + // V = W^T * U / ||W^T * U||_2 + blas.MatMul(*weight, true, *u, false, T(1), v, T(0)); + auto v_t_norm = + v_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast( + Array1(w)); + v_t.device(place) = v_t / (v_t_norm + v_t_norm.constant(eps)); + // U = W^T * V / ||W^T * V||_2 + blas.MatMul(*weight, false, *v, false, T(1), u, T(0)); + auto u_t_norm = + u_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast( + Array1(h)); + u_t.device(place) = u_t / (u_t_norm + u_t_norm.constant(eps)); + } + DenseTensor weight_v; + weight_v.Resize({h, 1}); + dev_ctx.template Alloc(&weight_v); + blas.MatMul(*weight, false, *v, false, T(1), &weight_v, T(0)); + auto weight_v_t = EigenTensor::From(weight_v); + sigma_t.device(place) = (u_t * weight_v_t) + .sum() + .eval() + .reshape(Array2(1, 1)) + .broadcast(Array2(h, w)); + weight_t.device(place) = weight_t / sigma_t; +} + +template +void SpectralNormKernel(const Context& dev_ctx, + const DenseTensor& weight, + const DenseTensor& u, + const DenseTensor& v, + int dim, + int power_iters, + float eps, + DenseTensor* out) { + const int h = u.dims()[0]; + const int w = v.dims()[0]; + + DenseTensor weight_mat; + auto dims = weight.dims(); + const int rank = dims.size(); + std::vector real_dims; + if (dim != 0) { + std::vector perm; + perm.push_back(dim); + real_dims.push_back(dims[dim]); + for (int i = 0; i < rank; i++) { + if (i != dim) { + perm.push_back(i); + real_dims.push_back(dims[i]); + } + } + weight_mat.Resize(phi::make_ddim(real_dims)); + dev_ctx.template Alloc(&weight_mat); + TransCompute2DTo5D(dev_ctx, weight, rank, perm, &weight_mat); + } else { + for (int i = 0; i < rank; i++) { + real_dims.push_back(i); + } + phi::Copy(dev_ctx, weight, dev_ctx.GetPlace(), true, &weight_mat); + } + weight_mat = weight_mat.Resize({h, w}); + + DenseTensor sigma; + sigma.Resize(weight_mat.dims()); + dev_ctx.template Alloc(&sigma); + DenseTensor uu, vv; + phi::Copy(dev_ctx, u, dev_ctx.GetPlace(), true, &uu); + phi::Copy(dev_ctx, v, dev_ctx.GetPlace(), true, &vv); + CalcMatrixSigmaAndNormWeight(dev_ctx, + &weight_mat, + &(uu.Resize({h, 1})), + &(vv.Resize({w, 1})), + &sigma, + power_iters, + eps); + + if (dim != 0) { + std::vector perm; + for (int i = 0; i < rank; i++) { + if (i < dim) { + perm.push_back(i + 1); + } else if (i == dim) { + perm.push_back(0); + } else { + perm.push_back(i); + } + } + out->Resize(dims); + dev_ctx.template Alloc(out); + TransCompute2DTo5D( + dev_ctx, weight_mat.Resize(phi::make_ddim(real_dims)), rank, perm, out); + } else { + phi::Copy(dev_ctx, weight_mat.Resize(dims), dev_ctx.GetPlace(), true, out); + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/spectral_norm_grad_kernel.h b/paddle/phi/kernels/spectral_norm_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..504cfba4b95e766fcc349b69637e00ade372e434 --- /dev/null +++ b/paddle/phi/kernels/spectral_norm_grad_kernel.h @@ -0,0 +1,29 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void SpectrumNormGradKernel(const Context& dev_ctx, + const DenseTensor& weight, + const DenseTensor& u, + const DenseTensor& v, + const DenseTensor& out_grad, + int dim, + int power_iters, + float eps, + DenseTensor* weight_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/spectral_norm_kernel.h b/paddle/phi/kernels/spectral_norm_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..26b1699898ea6e09bf1700617aebfc341daefa0d --- /dev/null +++ b/paddle/phi/kernels/spectral_norm_kernel.h @@ -0,0 +1,28 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void SpectrumNormKernel(const Context& dev_ctx, + const DenseTensor& weight, + const DenseTensor& u, + const DenseTensor& v, + int dim, + int power_iters, + float eps, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/spectral_norm_sig.cc b/paddle/phi/ops/compat/spectral_norm_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..ea11df24881aaf8bacbcc59ffbb35c4f9e3090ad --- /dev/null +++ b/paddle/phi/ops/compat/spectral_norm_sig.cc @@ -0,0 +1,39 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature SpectralNormOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("spectral_norm", + {"Weight", "U", "V"}, + {"dim", "power_iters", "eps"}, + {"Out"}); +} + +KernelSignature SpectralNormGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("spectral_norm_grad", + {"Weight", "U", "V", "Out@GRAD"}, + {"dim", "power_iters", "eps"}, + {"Weight@GRAD"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(spectral_norm, phi::SpectralNormOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(spectral_norm_grad, + phi::SpectralNormGradOpArgumentMapping);